Unverified Commit ae1f8350 authored by Gunjan Chhablani's avatar Gunjan Chhablani Committed by GitHub

Add PLBart (#13269)

* Init PLBART

* Add missing configuration file

* Add conversion script and configuration file

* Fix style

* Update modeling and conversion scripts

* Fix scale embedding in config

* Add comment

* Fix conversion script

* Add classification option to conversion script

* Fix vocab size in config doc

* Add tokenizer files from MBart50

* Allow no lang code in regular tokenizer

* Add PLBart Tokenizer Converters

* Remove mask from multi tokenizer

* Remove mask from multi tokenizer

* Change from MBart-50 to MBart tokenizer

* Fix names and modify src/tgt behavior

* Fix imports for tokenizer

* Remove <mask> from multi tokenizer

* Fix style

* Change tokenizer_class to processor_class

* Add attribute map to config class

* Update modeling file to modified MBart code

* Update configuration file to MBart style configuration

* Fix tokenizer

* Separate tokenizers

* Fix error in tokenization auto

* Copy MBart tests

* Replace with MBart tokenization tests

* Fix style

* Fix language code in multi tokenizer

* Fix configuration docs

* Add entry for plbart_multi in transformers init

* Add dummy objects and fix imports

* Fix modeling tests

* Add TODO in config

* Fix copyright year

* Fix modeling docs and test

* Fix some tokenization tests and style

* Add changes from review

* Fix copies

* Fix docs

* Fix docs

* Fix style

* Fix year

* Add changes from review

* Remove extra changes

* Fix base tokenizer and doc

* Fix style

* Fix modeling and slow tokenizer tests

* Remove Multi-tokenizer Converter and Tests

* Delete QA model and Multi Tokenizer dummy objects

* Fix repo consistency and code quality issues

* Fix example documentation

* Fix style

* Remove PLBartTokenizer from type checking in init

* Fix consistency issue

* Add changes from review

* Fix style

* Remove PLBartTokenizerFast

* Remove FastTokenizer converter

* Fix AutoTokenizer mapping

* Add plbart to toctree and fix consistency issues

* Add language codes tokenizer test

* Fix styling and doc issues

* Add fixes for failing tests

* Fix copies

* Fix failing modeling test

* Change assert to assertTrue in modeling tests
parent 2f2fefd6
@@ -246,6 +246,8 @@
title: Pegasus
- local: model_doc/phobert
title: PhoBERT
- local: model_doc/plbart
title: PLBart
- local: model_doc/poolformer
title: PoolFormer
- local: model_doc/prophetnet
...
@@ -215,6 +215,7 @@ Flax), PyTorch, and/or TensorFlow.
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ |
| Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ |
| PLBart | ✅ | ❌ | ✅ | ❌ | ❌ |
| PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
| QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ |
...
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PLBart
**DISCLAIMER:** If you see something strange, file a [GitHub Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title) and assign
[@gchhablani](https://www.github.com/gchhablani).
## Overview of PLBart
The PLBART model was proposed in [Unified Pre-training for Program Understanding and Generation](https://arxiv.org/abs/2103.06333) by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, Kai-Wei Chang.
This is a BART-like model which can be used to perform code summarization, code generation, and code translation tasks. The pre-trained model `plbart-base` has been trained using a multilingual denoising task
on Java, Python and English.
According to the abstract
*Code summarization and generation empower conversion between programming language (PL) and natural language (NL),
while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks.
PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding.
Experiments on code summarization in the English language, code generation, and code translation in seven programming languages
show that PLBART outperforms or rivals state-of-the-art models. Moreover, experiments on discriminative tasks, e.g., program
repair, clone detection, and vulnerable code detection, demonstrate PLBART's effectiveness in program understanding.
Furthermore, analysis reveals that PLBART learns program syntax, style (e.g., identifier naming convention), logical flow
(e.g., if block inside an else block is equivalent to else if block) that are crucial to program semantics and thus excels
even with limited annotations.*
This model was contributed by [gchhablani](https://huggingface.co/gchhablani). The authors' code can be found [here](https://github.com/wasiahmad/PLBART).
### Training of PLBart
PLBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for code-to-text, text-to-code, and code-to-code tasks. As the
model is multilingual, it expects the sequences in a specific format: a special language id token is added to both the
source and target text. The source text format is `X [eos, src_lang_code]`, where `X` is the source text. The
target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
However, for fine-tuning, the language token is sometimes omitted when a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode the source text format, and the call should be wrapped
inside the [`~PLBartTokenizer.as_target_tokenizer`] context manager to encode the target text format.
- Supervised training
```python
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="python", tgt_lang="en_XX")
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> expected_translation_english = "Returns the maximum value of a b c."
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
...     labels = tokenizer(expected_translation_english, return_tensors="pt")
>>> inputs["labels"] = labels["input_ids"]
>>> # forward pass
>>> model(**inputs)
```
- Generation
While generating the target text, set the `decoder_start_token_id` to the target language id. The following
example shows how to translate Python to English using the `uclanlp/plbart-python-en_XX` model.
```python
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-python-en_XX")
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"])
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
"Returns the maximum value of a b c."
```
## PLBartConfig
[[autodoc]] PLBartConfig
## PLBartTokenizer
[[autodoc]] PLBartTokenizer
- as_target_tokenizer
- build_inputs_with_special_tokens
## PLBartModel
[[autodoc]] PLBartModel
- forward
## PLBartForConditionalGeneration
[[autodoc]] PLBartForConditionalGeneration
- forward
## PLBartForSequenceClassification
[[autodoc]] PLBartForSequenceClassification
- forward
## PLBartForCausalLM
[[autodoc]] PLBartForCausalLM
- forward
\ No newline at end of file
@@ -57,6 +57,7 @@ Ready-made configurations include the following architectures:
- Marian
- mBART
- OpenAI GPT-2
- PLBart
- RoBERTa
- T5
- XLM-RoBERTa
...
@@ -263,6 +263,7 @@ _import_structure = {
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
"models.phobert": ["PhobertTokenizer"],
"models.plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
"models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], "models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
"models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"],
"models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
...@@ -410,6 +411,7 @@ if is_sentencepiece_available(): ...@@ -410,6 +411,7 @@ if is_sentencepiece_available():
_import_structure["models.mluke"].append("MLukeTokenizer") _import_structure["models.mluke"].append("MLukeTokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.plbart"].append("PLBartTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer")
_import_structure["models.rembert"].append("RemBertTokenizer") _import_structure["models.rembert"].append("RemBertTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
...@@ -1219,6 +1221,16 @@ if is_torch_available(): ...@@ -1219,6 +1221,16 @@ if is_torch_available():
"PerceiverPreTrainedModel", "PerceiverPreTrainedModel",
] ]
) )
_import_structure["models.plbart"].extend(
[
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"PLBartForCausalLM",
"PLBartForConditionalGeneration",
"PLBartForSequenceClassification",
"PLBartModel",
"PLBartPreTrainedModel",
]
)
_import_structure["models.poolformer"].extend( _import_structure["models.poolformer"].extend(
[ [
"POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -2498,6 +2510,7 @@ if TYPE_CHECKING: ...@@ -2498,6 +2510,7 @@ if TYPE_CHECKING:
from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
from .models.phobert import PhobertTokenizer from .models.phobert import PhobertTokenizer
from .models.plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer
from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
@@ -2630,6 +2643,7 @@ if TYPE_CHECKING:
from .models.mluke import MLukeTokenizer
from .models.mt5 import MT5Tokenizer
from .models.pegasus import PegasusTokenizer
from .models.plbart import PLBartTokenizer
from .models.reformer import ReformerTokenizer
from .models.rembert import RemBertTokenizer
from .models.speech_to_text import Speech2TextTokenizer
@@ -3292,6 +3306,14 @@ if TYPE_CHECKING:
PerceiverModel,
PerceiverPreTrainedModel,
)
from .models.plbart import (
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
PLBartForCausalLM,
PLBartForConditionalGeneration,
PLBartForSequenceClassification,
PLBartModel,
PLBartPreTrainedModel,
)
from .models.poolformer import (
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
PoolFormerForImageClassification,
...
@@ -83,6 +83,7 @@ from . import (
pegasus,
perceiver,
phobert,
plbart,
poolformer,
prophetnet,
qdqbert,
...
@@ -49,6 +49,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("perceiver", "PerceiverConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("plbart", "PLBartConfig"),
("beit", "BeitConfig"), ("beit", "BeitConfig"),
("rembert", "RemBertConfig"), ("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"), ("visual_bert", "VisualBertConfig"),
...@@ -143,6 +144,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ...@@ -143,6 +144,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
...@@ -228,6 +230,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -228,6 +230,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("perceiver", "Perceiver"), ("perceiver", "Perceiver"),
("gptj", "GPT-J"), ("gptj", "GPT-J"),
("beit", "BEiT"), ("beit", "BEiT"),
("plbart", "PLBart"),
("rembert", "RemBERT"), ("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"), ("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"), ("visual_bert", "VisualBert"),
......
@@ -44,6 +44,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("perceiver", "PerceiverModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("plbart", "PLBartModel"),
("beit", "BeitModel"), ("beit", "BeitModel"),
("rembert", "RemBertModel"), ("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"), ("visual_bert", "VisualBertModel"),
...@@ -163,6 +164,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ...@@ -163,6 +164,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
# Model with LM heads mapping # Model with LM heads mapping
("yoso", "YosoForMaskedLM"), ("yoso", "YosoForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"),
("plbart", "PLBartForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"), ("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"), ("gptj", "GPTJForCausalLM"),
...@@ -216,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -216,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Causal LM mapping # Model for Causal LM mapping
("xglm", "XGLMForCausalLM"), ("xglm", "XGLMForCausalLM"),
("plbart", "PLBartForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"), ("qdqbert", "QDQBertLMHeadModel"),
("trocr", "TrOCRForCausalLM"), ("trocr", "TrOCRForCausalLM"),
("gptj", "GPTJForCausalLM"), ("gptj", "GPTJForCausalLM"),
...@@ -361,6 +364,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( ...@@ -361,6 +364,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Seq2Seq Causal LM mapping # Model for Seq2Seq Causal LM mapping
("plbart", "PLBartForConditionalGeneration"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"),
("led", "LEDForConditionalGeneration"), ("led", "LEDForConditionalGeneration"),
...@@ -391,6 +395,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -391,6 +395,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
("yoso", "YosoForSequenceClassification"), ("yoso", "YosoForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("perceiver", "PerceiverForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"), ("qdqbert", "QDQBertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"), ("fnet", "FNetForSequenceClassification"),
......
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
}
if is_sentencepiece_available():
_import_structure["tokenization_plbart"] = ["PLBartTokenizer"]
if is_torch_available():
_import_structure["modeling_plbart"] = [
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"PLBartForCausalLM",
"PLBartForConditionalGeneration",
"PLBartForSequenceClassification",
"PLBartModel",
"PLBartPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
if is_sentencepiece_available():
from .tokenization_plbart import PLBartTokenizer
if is_torch_available():
from .modeling_plbart import (
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
PLBartForCausalLM,
PLBartForConditionalGeneration,
PLBartForSequenceClassification,
PLBartModel,
PLBartPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PLBART model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
logger = logging.get_logger(__name__)
PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/config.json",
# See all PLBART models at https://huggingface.co/models?filter=plbart
}
class PLBartConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate an
PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PLBART
[uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50005):
Vocabulary size of the PLBART model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`PLBartModel`].
d_model (`int`, *optional*, defaults to 768):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 6):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 6):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `True`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import PLBartModel, PLBartConfig
>>> # Initializing a PLBART uclanlp/plbart-base style configuration
>>> configuration = PLBartConfig()
>>> # Initializing a model from the uclanlp/plbart-base style configuration
>>> model = PLBartModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "plbart"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=50005,
max_position_embeddings=1024,
encoder_layers=6,
encoder_ffn_dim=3072,
encoder_attention_heads=12,
decoder_layers=6,
decoder_ffn_dim=3072,
decoder_attention_heads=12,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=768,
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=True,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
class PLBartOnnxConfig(OnnxConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.use_past:
return OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "sequence"}),
("past_keys", {0: "batch", 2: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)
else:
return OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)
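As a quick orientation to the `PLBartOnnxConfig` above: the short sketch below inspects the dynamic-axis declarations it exposes. This is illustrative only and not part of the diff; it assumes the `OnnxConfigWithPast` base class accepts a model config as its first constructor argument (as other ONNX configs in the library do) with `use_past` defaulting to `False`.
```python
# Illustrative sketch: inspect the dynamic axes declared by PLBartOnnxConfig.
# Assumes PLBartOnnxConfig(config) is a valid construction (use_past defaults to False).
from transformers import PLBartConfig
from transformers.models.plbart.configuration_plbart import PLBartOnnxConfig

onnx_config = PLBartOnnxConfig(PLBartConfig())

print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])
print(onnx_config.outputs)
# Without use_past, only last_hidden_state and encoder_last_hidden_state are reported.
```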
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from torch import nn
from transformers import PLBartConfig, PLBartForConditionalGeneration, PLBartForSequenceClassification
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
"decoder.output_projection.weight",
]
for k in ignore_keys:
state_dict.pop(k, None)
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
def convert_fairseq_plbart_checkpoint_from_disk(
checkpoint_path, hf_config_path="uclanlp/plbart-base", finetuned=False, classification=False
):
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
plbart_config = PLBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
if not classification:
model = PLBartForConditionalGeneration(plbart_config)
model.model.load_state_dict(state_dict)
if finetuned:
model.lm_head = make_linear_from_emb(model.model.shared)
else:
classification_head = {}
for key, value in state_dict.copy().items():
if key.startswith("classification_heads.sentence_classification_head"):
classification_head[key.replace("classification_heads.sentence_classification_head.", "")] = value
state_dict.pop(key)
model = PLBartForSequenceClassification(plbart_config)
model.model.load_state_dict(state_dict)
model.classification_head.load_state_dict(classification_head)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", type=str, help="model.pt on local filesystem.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--hf_config",
default="uclanlp/plbart-base",
type=str,
help="Which huggingface architecture to use: plbart-base",
)
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
parser.add_argument(
"--classification", action="store_true", help="whether the model is a classification checkpoint"
)
args = parser.parse_args()
model = convert_fairseq_plbart_checkpoint_from_disk(
args.fairseq_path,
hf_config_path=args.hf_config,
finetuned=args.finetuned,
classification=args.classification,
)
model.save_pretrained(args.pytorch_dump_folder_path)
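A hedged sketch of driving the converter above from Python rather than the command line; the checkpoint and output paths are placeholders, not files referenced in this PR.
```python
# Placeholder paths; convert_fairseq_plbart_checkpoint_from_disk is the function defined in the script above.
model = convert_fairseq_plbart_checkpoint_from_disk(
    "checkpoints/plbart_base.pt",          # fairseq model.pt on the local filesystem
    hf_config_path="uclanlp/plbart-base",  # config used to size the Hugging Face model
    finetuned=False,
    classification=False,                  # set True to build PLBartForSequenceClassification instead
)
model.save_pretrained("converted-plbart-base")
```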
@@ -2786,6 +2786,44 @@ class PerceiverPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = None
class PLBartForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PLBartForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PLBartForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PLBartModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PLBartPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
@@ -108,6 +108,13 @@ class PegasusTokenizer(metaclass=DummyObject):
requires_backends(self, ["sentencepiece"])
class PLBartTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
class ReformerTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
...
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, PLBartTokenizer, is_torch_available
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
if is_torch_available():
from transformers.models.plbart.modeling_plbart import shift_tokens_right
EN_CODE = 50003
PYTHON_CODE = 50002
@require_sentencepiece
@require_tokenizers
class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = PLBartTokenizer
rust_tokenizer_class = None
test_rust_tokenizer = False
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def test_full_base_tokenizer(self):
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[
value + tokenizer.fairseq_offset
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
end = tokenizer.vocab_size
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
def test_full_multi_tokenizer(self):
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[
value + tokenizer.fairseq_offset
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
end = tokenizer.vocab_size
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
@require_torch
@require_sentencepiece
@require_tokenizers
class PLBartPythonEnIntegrationTest(unittest.TestCase):
checkpoint_name = "uclanlp/plbart-python-en_XX"
src_text = [
"def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])",
"def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])",
]
tgt_text = [
"Returns the maximum value of a b c.",
"Sums the values of a b c.",
]
expected_src_tokens = [
134,
5452,
33460,
33441,
33463,
33465,
33463,
33449,
988,
20,
33456,
19,
33456,
771,
39,
4258,
889,
3318,
33441,
33463,
33465,
33463,
33449,
2471,
2,
PYTHON_CODE,
]
@classmethod
def setUpClass(cls):
cls.tokenizer: PLBartTokenizer = PLBartTokenizer.from_pretrained(
cls.checkpoint_name, language_codes="base", src_lang="python", tgt_lang="en_XX"
)
cls.pad_token_id = 1
return cls
def check_language_codes(self):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
def test_python_en_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
def test_python_en_tokenizer_decode_ignores_language_codes(self):
self.assertIn(PYTHON_CODE, self.tokenizer.all_special_ids)
generated_ids = [EN_CODE, 9037, 33442, 57, 752, 153, 14, 56, 18, 9, 2]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
expected_english = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
self.assertEqual(result, expected_english)
self.assertNotIn(self.tokenizer.eos_token, result)
def test_python_en_tokenizer_truncation(self):
src_text = ["def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])" * 20]
self.assertIsInstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], PYTHON_CODE)
self.assertEqual(len(ids), desired_max_length)
def test_mask_token(self):
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
def test_special_tokens_unaffected_by_save_load(self):
tmpdirname = tempfile.mkdtemp()
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
self.tokenizer.save_pretrained(tmpdirname)
new_tok = PLBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
@require_torch
def test_batch_fairseq_parity(self):
batch = self.tokenizer(self.src_text, padding=True)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
@require_torch
def test_python_en_tokenizer_prepare_batch(self):
batch = self.tokenizer(
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
)
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(
self.tgt_text,
padding=True,
truncation=True,
max_length=len(self.expected_src_tokens),
return_tensors="pt",
)
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 26), batch.input_ids.shape)
self.assertEqual((2, 26), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, PYTHON_CODE])
def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
with self.tokenizer.as_target_tokenizer():
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs(
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="java"
)
self.assertEqual(
nested_simplify(inputs),
{
# A, test, EOS, en_XX
"input_ids": [[150, 242, 2, 50003]],
"attention_mask": [[1, 1, 1, 1]],
# java
"forced_bos_token_id": 50001,
},
)
@@ -45,6 +45,9 @@ PRIVATE_MODELS = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
"PLBartDecoder", # Building part of bigger (tested) model.
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
"BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model. "BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
...@@ -119,6 +122,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ ...@@ -119,6 +122,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"PerceiverForOpticalFlow", "PerceiverForOpticalFlow",
"SegformerDecodeHead", "SegformerDecodeHead",
"FlaxBeitForMaskedImageModeling", "FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
"PLBartDecoderWrapper",
"BeitForMaskedImageModeling", "BeitForMaskedImageModeling",
"CLIPTextModel", "CLIPTextModel",
"CLIPVisionModel", "CLIPVisionModel",
......