Unverified Commit 6fc940ed authored by Suraj Patil, committed by GitHub

Add mBART-50 (#10154)

* add tokenizer for mBART-50

* update tokenizers

* make src_lang and tgt_lang optional

* update tokenizer test

* add setter

* update docs

* update conversion script

* update docs

* update conversion script

* update tokenizer

* update test

* update docs

* doc

* address Sylvain's suggestions

* fix test

* fix formatting

* nits
parent 57021887
@@ -217,6 +217,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[MBart-50](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[MPNet](https://huggingface.co/transformers/model_doc/mpnet.html)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/transformers/model_doc/mt5.html)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
......
@@ -161,48 +161,51 @@ and conversion utilities for the following models:
26. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
27. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
28. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
Jianfeng Lu, Tie-Yan Liu.
29. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
30. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__ by Jingqing Zhang, Yao Zhao,
Mohammad Saleh and Peter J. Liu.
31. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
32. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
33. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
34. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
Krishna, and Kurt W. Keutzer.
35. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
36. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
37. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
38. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
39. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
40. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
41. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
42. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
......
@@ -10,14 +10,14 @@
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
MBart and MBart-50
-----------------------------------------------------------------------------------------------------------------------
**DISCLAIMER:** If you see something strange, file a `Github Issue
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@patrickvonplaten
Overview of MBart
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation
@@ -31,17 +31,9 @@ on the encoder, decoder, or reconstructing parts of the text.
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`__
Training of MBart
_______________________________________________________________________________________________________________________
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation tasks. As the model is
multilingual it expects the sequences in a different format. A special language id token is added in both the source
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
@@ -76,6 +68,87 @@ the sequences for sequence-to-sequence fine-tuning.
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
Overview of MBart-50
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
MBart-50 was introduced in the `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning
<https://arxiv.org/abs/2008.00401>`__ paper by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav
Chaudhary, Jiatao Gu, Angela Fan. MBart-50 is created using the original ``mbart-large-cc25`` checkpoint by extending
its embedding layers with randomly initialized vectors for an extra set of 25 language tokens and then pretraining it
on 50 languages.
According to the abstract:
*Multilingual translation models can be created through multilingual finetuning. Instead of finetuning on one
direction, a pretrained model is finetuned on many directions at the same time. It demonstrates that pretrained models
can be extended to incorporate additional languages without loss of performance. Multilingual finetuning improves on
average 1 BLEU over the strongest baselines (being either multilingual from scratch or bilingual finetuning) while
improving 9.3 BLEU on average over bilingual baselines from scratch.*
Training of MBart-50
_______________________________________________________________________________________________________________________
The text format for MBart-50 is slightly different from mBART. For MBart-50 the language id token is used as a prefix
for both source and target text, i.e. the text format is :obj:`[lang_code] X [eos]`, where :obj:`lang_code` is the
source language id for source text and the target language id for target text, with :obj:`X` being the source or
target text respectively.
MBart-50 has its own tokenizer :class:`~transformers.MBart50Tokenizer`.
- Supervised training
.. code-block::
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
src_text = " UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
model_inputs = tokenizer(src_text, return_tensors="pt")
with tokenizer.as_target_tokenizer():
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
model(**model_inputs, labels=labels) # forward pass
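The prefix format can be verified by converting the encoded ids back to tokens. The snippet below is a minimal sketch
that reuses the ``tokenizer`` and ``model_inputs`` from the example above; the decoded pieces shown in the comment are
only illustrative.
.. code-block::
# inspect the special tokens added around the source text: [src_lang_code] ... [eos]
print(tokenizer.convert_ids_to_tokens(model_inputs["input_ids"][0].tolist()))
# roughly: ['en_XX', '▁UN', '▁Chief', ..., '</s>'] -- the exact subword pieces may differ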
- Generation
To generate using the mBART-50 multilingual translation models, :obj:`eos_token_id` is used as the
:obj:`decoder_start_token_id` and the target language id is forced as the first generated token. To force the
target language id as the first generated token, pass the :obj:`forced_bos_token_id` parameter to the :obj:`generate`
method. The following example shows how to translate from Hindi to French and from Arabic to English using the
``facebook/mbart-large-50-many-to-many-mmt`` checkpoint.
.. code-block::
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# translate Hindi to French
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."
# translate Arabic to English
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt")
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "The Secretary-General of the United Nations says there is no military solution in Syria."
MBartConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -97,6 +170,20 @@ MBartTokenizerFast
:members:
MBart50Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBart50Tokenizer
:members:
MBart50TokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBart50TokenizerFast
:members:
MBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -381,6 +381,15 @@ For the full list, refer to `https://huggingface.co/models <https://huggingface.
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/mbart-large-en-ro`` | | 24-layer, 1024-hidden, 16-heads, 610M parameters |
| | | | mbart-large-cc25 model finetuned on WMT english romanian translation. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/mbart-large-50`` | | 24-layer, 1024-hidden, 16-heads, |
| | | | mBART model trained on 50 languages' monolingual corpus. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/mbart-large-50-one-to-many-mmt`` | | 24-layer, 1024-hidden, 16-heads, |
| | | | mbart-50-large model finetuned for one (English) to many multilingual machine translation covering 50 languages. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/mbart-large-50-many-to-many-mmt`` | | 24-layer, 1024-hidden, 16-heads, |
| | | | mbart-50-large model finetuned for many to many multilingual machine translation covering 50 languages. |
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Lxmert | ``lxmert-base-uncased`` | | 9-language layers, 9-relationship layers, and 12-cross-modality layers |
| | | | 768-hidden, 12-heads (for each layer) ~ 228M parameters |
......
@@ -260,6 +260,7 @@ if is_sentencepiece_available():
_import_structure["models.camembert"].append("CamembertTokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer")
...@@ -296,6 +297,7 @@ if is_tokenizers_available(): ...@@ -296,6 +297,7 @@ if is_tokenizers_available():
_import_structure["models.longformer"].append("LongformerTokenizerFast") _import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast") _import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast") _import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart"].append("MBart50TokenizerFast")
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast") _import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast") _import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast") _import_structure["models.mt5"].append("MT5TokenizerFast")
...@@ -1391,7 +1393,7 @@ if TYPE_CHECKING: ...@@ -1391,7 +1393,7 @@ if TYPE_CHECKING:
from .models.bert_generation import BertGenerationTokenizer from .models.bert_generation import BertGenerationTokenizer
from .models.camembert import CamembertTokenizer from .models.camembert import CamembertTokenizer
from .models.marian import MarianTokenizer from .models.marian import MarianTokenizer
from .models.mbart import MBartTokenizer from .models.mbart import MBart50Tokenizer, MBartTokenizer
from .models.mt5 import MT5Tokenizer
from .models.pegasus import PegasusTokenizer
from .models.reformer import ReformerTokenizer
@@ -1419,7 +1421,7 @@ if TYPE_CHECKING:
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast
......
@@ -500,6 +500,35 @@ class MBartConverter(SpmConverter):
)
class MBart50Converter(SpmConverter):
def vocab(self, proto):
vocab = [
("<s>", 0.0),
("<pad>", 0.0),
("</s>", 0.0),
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
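# the first three SentencePiece pieces (<unk>, <s>, </s>) are skipped: fairseq re-orders them as <s>, <pad>, </s>, <unk> above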
# fmt: off
vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
# fmt: on
vocab += [("<mask>", 0.0)]
return vocab
def unk_id(self, proto):
return 3
def post_processor(self):
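# note: this default template prefixes en_XX; MBart50TokenizerFast resets the post-processor per language
# via set_src_lang_special_tokens / set_tgt_lang_special_tokens when the tokenizer is loaded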
return processors.TemplateProcessing(
single="en_XX $A </s>",
pair="en_XX $A $B </s>",
special_tokens=[
("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
],
)
class XLMRobertaConverter(SpmConverter):
def vocab(self, proto):
vocab = [
@@ -637,6 +666,7 @@ SLOW_TO_FAST_CONVERTERS = {
"LEDTokenizer": RobertaConverter,
"LxmertTokenizer": BertConverter,
"MBartTokenizer": MBartConverter,
"MBart50Tokenizer": MBart50Converter,
"MPNetTokenizer": MPNetConverter, "MPNetTokenizer": MPNetConverter,
"MobileBertTokenizer": BertConverter, "MobileBertTokenizer": BertConverter,
"OpenAIGPTTokenizer": OpenAIGPTConverter, "OpenAIGPTTokenizer": OpenAIGPTConverter,
......
@@ -32,9 +32,11 @@ _import_structure = {
if is_sentencepiece_available():
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
if is_torch_available():
_import_structure["modeling_mbart"] = [
@@ -56,8 +58,10 @@ if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
from .tokenization_mbart_fast import MBartTokenizerFast
if is_torch_available():
......
@@ -15,19 +15,49 @@
import argparse
import torch
from torch import nn
from transformers import MBartConfig, MBartForConditionalGeneration
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
"decoder.output_projection.weight",
]
for k in ignore_keys:
state_dict.pop(k, None)
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
def convert_fairseq_mbart_checkpoint_from_disk(
checkpoint_path, hf_config_path="facebook/mbart-large-en-ro", finetuned=False, mbart_50=False
):
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
if mbart_50 and finetuned:
mbart_config.activation_function = "relu"
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = BartForConditionalGeneration(mbart_config) model = MBartForConditionalGeneration(mbart_config)
model.model.load_state_dict(state_dict) model.model.load_state_dict(state_dict)
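# for fine-tuned checkpoints, rebuild the LM head from the shared embeddings
# (the fairseq decoder.output_projection weight is dropped in remove_ignore_keys_ above)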
if finetuned:
model.lm_head = make_linear_from_emb(model.model.shared)
return model
@@ -42,8 +72,12 @@ if __name__ == "__main__":
"--hf_config",
default="facebook/mbart-large-cc25",
type=str,
help="Which huggingface architecture to use: bart-large-xsum", help="Which huggingface architecture to use: mbart-large",
) )
parser.add_argument("--mbart_50", action="store_true", help="whether the model is mMART-50 checkpoint")
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
args = parser.parse_args()
model = convert_fairseq_mbart_checkpoint_from_disk(
args.fairseq_path, hf_config_path=args.hf_config, finetuned=args.finetuned, mbart_50=args.mbart_50
)
model.save_pretrained(args.pytorch_dump_folder_path)
# coding=utf-8
# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from shutil import copyfile
from typing import Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"]
SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
# fmt: off
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]
# fmt: on
class MBart50Tokenizer(PreTrainedTokenizer):
"""
Construct a MBart50 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
src_lang (:obj:`str`, `optional`):
A string representing the source language.
tgt_lang (:obj:`str`, `optional`):
A string representing the target language.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sequence token.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
Examples::
>>> from transformers import MBart50Tokenizer
>>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
>>> # model(**model_inputs, labels=labels) should work
"""
vocab_files_names = VOCAB_FILES_NAMES
max_model_input_sizes = {m: 1024 for m in _all_mbart50_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}}
model_input_names = ["input_ids", "attention_mask"]
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(
self,
vocab_file,
src_lang=None,
tgt_lang=None,
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
super().__init__(
src_lang=src_lang,
tgt_lang=tgt_lang,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file
# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
# Mimic fairseq token-to-id alignment for the first 4 tokens
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1
self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self._src_lang = src_lang if src_lang is not None else "en_XX"
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
self.tgt_lang = tgt_lang
self.set_src_lang_special_tokens(self._src_lang)
@property
def vocab_size(self) -> int:
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
@property
def src_lang(self) -> str:
return self._src_lang
@src_lang.setter
def src_lang(self, new_src_lang: str) -> None:
self._src_lang = new_src_lang
self.set_src_lang_special_tokens(self._src_lang)
def __getstate__(self) -> Dict:
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d: Dict) -> None:
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def get_vocab(self) -> Dict:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> List[str]:
return self.sp_model.EncodeAsPieces(text)
def _convert_token_to_id(self, token: str) -> int:
""" Converts a token (str) in an id using the vocab. """
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
# Need to return unknown token if the SP model returned 0
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An MBART-50 sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
- ``labels``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def prepare_seq2seq_batch(
self,
src_texts: List[str],
src_lang: str = "en_XX",
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
**kwargs,
) -> BatchEncoding:
self.src_lang = src_lang
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
self.cur_lang_code_id = self.lang_code_to_id[src_lang]
self.prefix_tokens = [self.cur_lang_code_id]
self.suffix_tokens = [self.eos_token_id]
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
"""Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos]."""
self.cur_lang_code_id = self.lang_code_to_id[tgt_lang]
self.prefix_tokens = [self.cur_lang_code_id]
self.suffix_tokens = [self.eos_token_id]
# coding=utf-8
# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple
from tokenizers import processors
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
if is_sentencepiece_available():
from .tokenization_mbart50 import MBart50Tokenizer
else:
MBart50Tokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"]
SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
tokenizer_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json"
# fmt: off
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]
# fmt: on
class MBart50TokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" MBART tokenizer for mBART-50 (backed by HuggingFace's `tokenizers` library). Based on `BPE
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
methods. Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
src_lang (:obj:`str`, `optional`):
A string representing the source language.
tgt_lang (:obj:`str`, `optional`):
A string representing the target language.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sequence token.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
Examples::
>>> from transformers import MBart50TokenizerFast
>>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
>>> # model(**model_inputs, labels=labels) should work
"""
vocab_files_names = VOCAB_FILES_NAMES
max_model_input_sizes = {m: 1024 for m in _all_mbart50_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}}
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = MBart50Tokenizer
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(
self,
vocab_file,
src_lang=None,
tgt_lang=None,
tokenizer_file=None,
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
super().__init__(
vocab_file,
src_lang=src_lang,
tgt_lang=tgt_lang,
tokenizer_file=tokenizer_file,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)
self.vocab_file = vocab_file
self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
self.lang_code_to_id = {
lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
}
self._src_lang = src_lang if src_lang is not None else "en_XX"
self.tgt_lang = tgt_lang
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
self.set_src_lang_special_tokens(self._src_lang)
@property
def src_lang(self) -> str:
return self._src_lang
@src_lang.setter
def src_lang(self, new_src_lang: str) -> None:
self._src_lang = new_src_lang
self.set_src_lang_special_tokens(self._src_lang)
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of ids.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. The special tokens depend on calling set_lang.
An MBART-50 sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
- ``labels``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def prepare_seq2seq_batch(
self,
src_texts: List[str],
src_lang: str = "en_XX",
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
**kwargs,
) -> BatchEncoding:
self.src_lang = src_lang
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
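# Note: this method only switches src_lang/tgt_lang before deferring to the base
# prepare_seq2seq_batch implementation for the actual encoding; see the usage sketch after
# this class for an assumed end-to-end example.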
@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
sequence-to-sequence models that need slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
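# Usage sketch (``tokenizer`` and ``tgt_texts`` are assumed to exist):
#
#     with tokenizer.as_target_tokenizer():
#         labels = tokenizer(tgt_texts, return_tensors="pt").input_ids
#
# Inside the block, targets are wrapped as [tgt_lang_code] ... [eos]; on exit the
# source-language special tokens are restored.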
def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
self.cur_lang_code_id = self.convert_tokens_to_ids(src_lang)
self.prefix_tokens = [self.cur_lang_code_id]
self.suffix_tokens = [self.eos_token_id]
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
self._tokenizer.post_processor = processors.TemplateProcessing(
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
)
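# In other words, after set_src_lang_special_tokens("en_XX") the fast tokenizer's post-processor
# template is roughly "en_XX $A </s>" for single sequences and "en_XX $A $B </s>" for pairs
# (a sketch; the concrete tokens come from the loaded vocabulary).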
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
"""Reset the special tokens to the target language setting. prefix=[src_lang_code] and suffix=[eos]."""
self.cur_lang_code_id = self.convert_tokens_to_ids(tgt_lang)
self.prefix_tokens = [self.cur_lang_code_id]
self.suffix_tokens = [self.eos_token_id]
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
self._tokenizer.post_processor = processors.TemplateProcessing(
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
)
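# Same template shape as above but with the target language code, e.g. roughly "ro_RO $A </s>"
# after set_tgt_lang_special_tokens("ro_RO") -- which is what as_target_tokenizer relies on when
# encoding labels.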
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
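A minimal usage sketch of the fast tokenizer defined above (not part of the file; the checkpoint name and example sentences are taken from the integration test further down, and sentencepiece, tokenizers and torch are assumed to be installed):

from transformers import MBart50TokenizerFast

# Assumed checkpoint: the one-to-many model exercised by the integration test below.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX", tgt_lang="ro_RO"
)

src_text = " UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

# Encoder inputs come out as [en_XX] X [</s>].
model_inputs = tokenizer(src_text, return_tensors="pt")

# Decoder labels come out as [ro_RO] X [</s>].
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, return_tensors="pt").input_ids

The integration test below checks exactly this shape: the first id of each encoded sequence is the language code and the last is 2 (</s>).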
...@@ -47,6 +47,15 @@ class MarianTokenizer:
requires_sentencepiece(self)
class MBart50Tokenizer:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_sentencepiece(self)
class MBartTokenizer:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
...
...@@ -164,6 +164,15 @@ class LxmertTokenizerFast:
requires_tokenizers(self)
class MBart50TokenizerFast:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tokenizers(self)
class MBartTokenizerFast:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)
...
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
from transformers.file_utils import is_sentencepiece_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
from .test_tokenization_common import TokenizerTesterMixin
if is_sentencepiece_available():
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
if is_torch_available():
from transformers.models.mbart.modeling_mbart import shift_tokens_right
EN_CODE = 250004
RO_CODE = 250020
@require_sentencepiece
@require_tokenizers
class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MBart50Tokenizer
rust_tokenizer_class = MBart50TokenizerFast
test_rust_tokenizer = True
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def test_full_tokenizer(self):
tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
# fmt: off
[SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."],
# fmt: on
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[
value + tokenizer.fairseq_offset
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
# fmt: off
[SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "<unk>", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "<unk>", "."],
# fmt: on
)
@require_torch
@require_sentencepiece
@require_tokenizers
class MBart50OneToManyIntegrationTest(unittest.TestCase):
checkpoint_name = "facebook/mbart-large-50-one-to-many-mmt"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
""" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
]
tgt_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
]
expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
@classmethod
def setUpClass(cls):
cls.tokenizer: MBart50Tokenizer = MBart50Tokenizer.from_pretrained(
cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
)
cls.pad_token_id = 1
return cls
def check_language_codes(self):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["mr_IN"], 250038)
def test_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
self.assertEqual(result, expected_romanian)
self.assertNotIn(self.tokenizer.eos_token, result)
def test_tokenizer_truncation(self):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text,
max_length=desired_max_length,
).input_ids[0]
self.assertEqual(ids[0], EN_CODE)
self.assertEqual(ids[-1], 2)
self.assertEqual(len(ids), desired_max_length)
def test_mask_token(self):
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250053, 250001])
def test_special_tokens_unaffected_by_save_load(self):
tmpdirname = tempfile.mkdtemp()
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
for k in batch:
batch[k] = batch[k].tolist()
# reference fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
assert batch.input_ids[1][0] == EN_CODE
assert batch.input_ids[1][-1] == 2
assert batch.labels[1][0] == RO_CODE
assert batch.labels[1][-1] == 2
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
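# Note on the assertion above: mbart's shift_tokens_right wraps the last non-pad token to
# position 0, so for labels shaped "[ro_RO] X </s>" the decoder_input_ids start with
# [2, RO_CODE, ...], i.e. </s> followed by the target language code.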
@require_torch
def test_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, 0]) # decoder_start_token_id
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)