"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ba1b24e07bebc8e36b464bf7a403feb4f3ccb807"
Unverified commit eef66035, authored by Patrick von Platen, committed by GitHub

[PyTorch Bart] Split Bart into different models (#9343)

* first try

* remove old template

* finish bart

* finish mbart

* delete unnecessary line

* init pegasus

* save intermediate

* correct pegasus

* finish pegasus

* remove cookie cutter leftover

* add marian

* finish blenderbot

* replace in file

* correctly split blenderbot

* delete "old" folder

* correct "add statement"

* adapt config for tf comp

* correct configs for tf

* remove ipdb

* fix more stuff

* fix mbart

* push pegasus fix

* fix mbart

* more fixes

* fix research projects code

* finish docs for bart, mbart, and marian

* delete unnecessary file

* correct attn typo

* correct configs

* remove pegasus for seq class

* correct peg docs

* correct peg docs

* finish configs

* further improve docs

* add copied from statements to mbart

* fix copied from in mbart

* add copy statements to marian

* add copied from to marian

* add pegasus copied from

* finish pegasus

* finish copied from

* Apply suggestions from code review

* make style

* backward comp blenderbot

* apply lysandres and sylvains suggestions

* apply suggestions

* push last fixes

* fix docs

* fix tok tests

* fix imports code style

* fix doc
parent 4eec5d0c
@@ -220,6 +220,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| BlenderbotSmall | ✅ | ❌ | ✅ | ❌ | ❌ |
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -361,6 +363,7 @@ TensorFlow and/or Flax.
    model_doc/bertweet
    model_doc/bertgeneration
    model_doc/blenderbot
+   model_doc/blenderbot_small
    model_doc/camembert
    model_doc/ctrl
    model_doc/deberta
......
@@ -64,7 +64,6 @@ Implementation Notes
  summarization, see the example in that docstrings.
- Models that load the `facebook/bart-large-cnn` weights will not have a :obj:`mask_token_id`, or be able to perform
  mask-filling tasks.
-- For training/forward passes that don't involve beam search, pass :obj:`use_cache=False`.

Mask Filling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -43,13 +43,10 @@ Implementation Notes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Blenderbot uses a standard `seq2seq model transformer <https://arxiv.org/pdf/1706.03762.pdf>`__ based architecture.
-- It inherits completely from :class:`~transformers.BartForConditionalGeneration`
-- Even though blenderbot is one model, it uses two tokenizers :class:`~transformers.BlenderbotSmallTokenizer` for 90M
-  checkpoint and :class:`~transformers.BlenderbotTokenizer` for all other checkpoints.
-- :class:`~transformers.BlenderbotSmallTokenizer` will always return :class:`~transformers.BlenderbotSmallTokenizer`,
-  regardless of checkpoint. To use the 3B parameter checkpoint, you must call
-  :class:`~transformers.BlenderbotTokenizer` directly.
- Available checkpoints can be found in the `model hub <https://huggingface.co/models?search=blenderbot>`__.
+- This is the `default` Blenderbot model class. However, some smaller checkpoints, such as
+  ``facebook/blenderbot_small_90M``, have a different architecture and consequently should be used with
+  `BlenderbotSmall <https://huggingface.co/transformers/master/model_doc/blenderbot_small.html>`__.

Usage
@@ -59,26 +56,15 @@ Here is an example of model usage:

.. code-block::

-   >>> from transformers import BlenderbotSmallTokenizer, BlenderbotForConditionalGeneration
+   >>> from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
-   >>> mname = 'facebook/blenderbot-90M'
+   >>> mname = 'facebook/blenderbot-400M-distill'
    >>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
-   >>> tokenizer = BlenderbotSmallTokenizer.from_pretrained(mname)
+   >>> tokenizer = BlenderbotTokenizer.from_pretrained(mname)
    >>> UTTERANCE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([UTTERANCE], return_tensors='pt')
    >>> reply_ids = model.generate(**inputs)
-   >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids])
+   >>> print(tokenizer.batch_decode(reply_ids))
+   ["<s> That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?</s>"]

-Here is how you can check out config values:
-
-.. code-block::
-
-   >>> from transformers import BlenderbotConfig
-   >>> config_90 = BlenderbotConfig.from_pretrained("facebook/blenderbot-90M")
-   >>> config_90.to_diff_dict()  # show interesting Values.
-   >>> configuration_3B = BlenderbotConfig("facebook/blenderbot-3B")
-   >>> configuration_3B.to_diff_dict()

BlenderbotConfig
@@ -93,12 +79,6 @@ BlenderbotTokenizer
.. autoclass:: transformers.BlenderbotTokenizer
    :members: build_inputs_with_special_tokens

-BlenderbotSmallTokenizer
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: transformers.BlenderbotSmallTokenizer
-    :members:

BlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -106,7 +86,7 @@ BlenderbotModel
See :obj:`transformers.BartModel` for arguments to `forward` and `generate`

.. autoclass:: transformers.BlenderbotModel
-    :members:
+    :members: forward

BlenderbotForConditionalGeneration
@@ -115,7 +95,7 @@ BlenderbotForConditionalGeneration
See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` and `generate`

.. autoclass:: transformers.BlenderbotForConditionalGeneration
-    :members:
+    :members: forward

TFBlenderbotForConditionalGeneration
......
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Blenderbot Small
-----------------------------------------------------------------------------------------------------------------------
Note that :class:`~transformers.BlenderbotSmallModel` and
:class:`~transformers.BlenderbotSmallForConditionalGeneration` are only used in combination with the checkpoint
`facebook/blenderbot-90M <https://huggingface.co/facebook/blenderbot-90M>`__. Larger Blenderbot checkpoints should
instead be used with :class:`~transformers.BlenderbotModel` and
:class:`~transformers.BlenderbotForConditionalGeneration`
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Blender chatbot model was proposed in `Recipes for building an open-domain chatbot
<https://arxiv.org/pdf/2004.13637.pdf>`__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu,
Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau and Jason Weston on 30 Apr 2020.
The abstract of the paper is the following:
*Building open-domain chatbots is a challenging area for machine learning research. While prior work has shown that
scaling neural models in the number of parameters and the size of the data they are trained on gives improved results,
we show that other ingredients are important for a high-performing chatbot. Good conversation requires a number of
skills that an expert conversationalist blends in a seamless way: providing engaging talking points and listening to
their partners, and displaying knowledge, empathy and personality appropriately, while maintaining a consistent
persona. We show that large scale models can learn these skills when given appropriate training data and choice of
generation strategy. We build variants of these recipes with 90M, 2.7B and 9.4B parameter models, and make our models
and code publicly available. Human evaluations show our best models are superior to existing approaches in multi-turn
dialogue in terms of engagingness and humanness measurements. We then discuss the limitations of this work by analyzing
failure cases of our models.*
The authors' code can be found `here <https://github.com/facebookresearch/ParlAI>`__ .
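
Usage is analogous to the regular Blenderbot classes. The following is only an illustrative sketch, assuming the
``facebook/blenderbot-90M`` checkpoint mentioned above; the exact reply text depends on the checkpoint and generation
settings:

.. code-block::

    >>> from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForConditionalGeneration

    >>> mname = 'facebook/blenderbot-90M'
    >>> model = BlenderbotSmallForConditionalGeneration.from_pretrained(mname)
    >>> tokenizer = BlenderbotSmallTokenizer.from_pretrained(mname)
    >>> UTTERANCE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([UTTERANCE], return_tensors='pt')
    >>> reply_ids = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))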
BlenderbotSmallConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotSmallConfig
:members:
BlenderbotSmallTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotSmallTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary
BlenderbotSmallModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotSmallModel
:members: forward
BlenderbotSmallForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
:members: forward
@@ -33,7 +33,6 @@ Implementation Notes
- The modeling code is the same as :class:`~transformers.BartForConditionalGeneration` with a few minor modifications:

  - static (sinusoid) positional embeddings (:obj:`MarianConfig.static_position_embeddings=True`)
-  - a new final_logits_bias (:obj:`MarianConfig.add_bias_logits=True`)
  - no layernorm_embedding (:obj:`MarianConfig.normalize_embedding=False`)
  - the model starts generating with :obj:`pad_token_id` (which has 0 as a token_embedding) as the prefix (Bart uses
    :obj:`<s/>`),

@@ -56,9 +55,10 @@ Examples
- Since Marian models are smaller than many other translation models available in the library, they can be useful for
  fine-tuning experiments and integration tests.
-- :prefix_link:`Fine-tune on TPU <examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh>`
-- :prefix_link:`Fine-tune on GPU <examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh>`
-- :prefix_link:`Fine-tune on GPU with pytorch-lightning <examples/seq2seq/distil_marian_no_teacher.sh>`
+- `Fine-tune on GPU
+  <https://github.com/huggingface/transformers/blob/master/examples/research_projects/seq2seq-distillation/train_distil_marian_enro_teacher.sh>`__
+- `Fine-tune on GPU with pytorch-lightning
+  <https://github.com/huggingface/transformers/blob/master/examples/research_projects/seq2seq-distillation/train_distil_marian_no_teacher.sh>`__

Multilingual Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -179,10 +179,18 @@ MarianTokenizer
    :members: prepare_seq2seq_batch

+MarianModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.MarianModel
+    :members: forward

MarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MarianMTModel
+    :members: forward

TFMarianMTModel
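
For reference, the translation workflow these documented classes support looks roughly as follows. This is a sketch
only, assuming the ``Helsinki-NLP/opus-mt-en-de`` checkpoint:

.. code-block::

    >>> from transformers import MarianMTModel, MarianTokenizer

    >>> mname = 'Helsinki-NLP/opus-mt-en-de'
    >>> tokenizer = MarianTokenizer.from_pretrained(mname)
    >>> model = MarianMTModel.from_pretrained(mname)
    >>> batch = tokenizer(["Machine learning is great."], return_tensors="pt")
    >>> translated_ids = model.generate(**batch)
    >>> tokenizer.batch_decode(translated_ids, skip_special_tokens=True)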
......
@@ -111,6 +111,19 @@ MBartForConditionalGeneration
    :members:

+MBartForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.MBartForQuestionAnswering
+    :members:
+MBartForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.MBartForSequenceClassification

TFMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
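
The two new head classes are constructed like their Bart counterparts. A minimal sketch (randomly initialized
weights, for illustration only):

.. code-block::

    >>> from transformers import MBartConfig, MBartForQuestionAnswering, MBartForSequenceClassification

    >>> config = MBartConfig()  # default, randomly initialized configuration
    >>> qa_model = MBartForQuestionAnswering(config)
    >>> cls_model = MBartForSequenceClassification(config)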
......
@@ -65,7 +65,6 @@ Implementation Notes
- Some key configuration differences:

  - static, sinusoidal position embeddings
-  - no :obj:`layernorm_embedding` (:obj:`PegasusConfig.normalize_embedding=False`)
  - the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix.
  - more beams are used (:obj:`num_beams=8`)
- All pretrained pegasus checkpoints are the same besides three attributes: :obj:`tokenizer.model_max_length` (maximum

@@ -122,12 +121,14 @@ PegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PegasusModel
+    :members: forward

PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PegasusForConditionalGeneration
+    :members: forward

TFPegasusForConditionalGeneration
......
@@ -23,7 +23,7 @@ from pack_dataset import pack_data_dir
from parameterized import parameterized
from save_len_file import save_len_file
from transformers import AutoTokenizer
-from transformers.models.bart.modeling_bart import shift_tokens_right
+from transformers.models.mbart.modeling_mbart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
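
The mbart-style ``shift_tokens_right`` the test now imports wraps the last non-pad token around to position 0 instead
of prepending a fixed decoder start token. A toy illustration (made-up token ids, pad id assumed to be 1):

.. code-block::

    import torch
    from transformers.models.mbart.modeling_mbart import shift_tokens_right

    # 1 is the pad token here; 2 stands in for the final special token of the sequence
    labels = torch.tensor([[5, 6, 7, 2, 1, 1]])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=1)
    # the last non-pad token (2) is wrapped to the front:
    # tensor([[2, 5, 6, 7, 2, 1]])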
......
@@ -33,9 +33,8 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
from transformers.file_utils import cached_property
-from transformers.models.bart.modeling_bart import shift_tokens_right

try:
@@ -305,15 +304,9 @@ class Seq2SeqDataCollator:
        labels = trim_batch(labels, self.pad_token_id)
        input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
-        if isinstance(self.tokenizer, T5Tokenizer):
-            decoder_input_ids = self._shift_right_t5(labels)
-        else:
-            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
-            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch
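
The collator can drop ``decoder_input_ids`` because the seq2seq models now build them from ``labels`` when they are
not supplied. A sketch of what the slimmed-down collator relies on, assuming the ``facebook/bart-large`` checkpoint:

.. code-block::

    from transformers import BartForConditionalGeneration, BartTokenizer

    tok = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    batch = tok(["My friends are cool"], return_tensors="pt")
    labels = tok(["My friends are cool but they eat too many carbs."], return_tensors="pt").input_ids

    # no decoder_input_ids passed: the model derives them from `labels` internally
    outputs = model(**batch, labels=labels)
    print(outputs.loss)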
......
@@ -120,11 +120,11 @@ from .models.bert import (
from .models.bert_generation import BertGenerationConfig
from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .models.bertweet import BertweetTokenizer
-from .models.blenderbot import (
-    BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    BlenderbotConfig,
-    BlenderbotSmallTokenizer,
-    BlenderbotTokenizer,
-)
+from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer
+from .models.blenderbot_small import (
+    BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    BlenderbotSmallConfig,
+    BlenderbotSmallTokenizer,
+)
from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer
@@ -415,6 +415,11 @@ if is_torch_available():
        BlenderbotForConditionalGeneration,
        BlenderbotModel,
    )
+    from .models.blenderbot_small import (
+        BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        BlenderbotSmallForConditionalGeneration,
+        BlenderbotSmallModel,
+    )
    from .models.camembert import (
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        CamembertForCausalLM,
@@ -536,8 +541,13 @@ if is_torch_available():
        LxmertVisualFeatureEncoder,
        LxmertXLayer,
    )
-    from .models.marian import MarianMTModel
+    from .models.marian import MarianModel, MarianMTModel
-    from .models.mbart import MBartForConditionalGeneration, MBartModel
+    from .models.mbart import (
+        MBartForConditionalGeneration,
+        MBartForQuestionAnswering,
+        MBartForSequenceClassification,
+        MBartModel,
+    )
    from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
    from .models.mobilebert import (
        MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......
@@ -23,6 +23,10 @@ from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartCo
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
+from ..blenderbot_small.configuration_blenderbot_small import (
+    BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    BlenderbotSmallConfig,
+)
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
@@ -68,6 +72,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    for pretrained_map in [
        # Add archive maps here
        LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -108,6 +113,7 @@ CONFIG_MAPPING = OrderedDict(
    [
        # Add configs here
        ("led", LEDConfig),
+        ("blenderbot-small", BlenderbotSmallConfig),
        ("retribert", RetriBertConfig),
        ("mt5", MT5Config),
        ("t5", T5Config),
@@ -154,6 +160,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("led", "LED"),
+        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
......
@@ -51,6 +51,7 @@ from ..bert.modeling_bert import (
)
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
+from ..blenderbot_small.modeling_blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
from ..camembert.modeling_camembert import (
    CamembertForCausalLM,
    CamembertForMaskedLM,
@@ -116,8 +117,13 @@ from ..longformer.modeling_longformer import (
    LongformerModel,
)
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
-from ..marian.modeling_marian import MarianMTModel
+from ..marian.modeling_marian import MarianModel, MarianMTModel
-from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel
+from ..mbart.modeling_mbart import (
+    MBartForConditionalGeneration,
+    MBartForQuestionAnswering,
+    MBartForSequenceClassification,
+    MBartModel,
+)
from ..mobilebert.modeling_mobilebert import (
    MobileBertForMaskedLM,
    MobileBertForMultipleChoice,
@@ -215,6 +221,7 @@ from .configuration_auto import (
    BertConfig,
    BertGenerationConfig,
    BlenderbotConfig,
+    BlenderbotSmallConfig,
    CamembertConfig,
    CTRLConfig,
    DebertaConfig,
@@ -260,6 +267,7 @@ MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (LEDConfig, LEDModel),
+        (BlenderbotSmallConfig, BlenderbotSmallModel),
        (RetriBertConfig, RetriBertModel),
        (MT5Config, MT5Model),
        (T5Config, T5Model),
@@ -297,6 +305,7 @@ MODEL_MAPPING = OrderedDict(
        (ProphetNetConfig, ProphetNetModel),
        (MPNetConfig, MPNetModel),
        (TapasConfig, TapasModel),
+        (MarianConfig, MarianModel),
    ]
)
@@ -336,6 +345,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (LEDConfig, LEDForConditionalGeneration),
+        (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
@@ -417,6 +427,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (LEDConfig, LEDForConditionalGeneration),
+        (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
        (MT5Config, MT5ForConditionalGeneration),
        (T5Config, T5ForConditionalGeneration),
        (PegasusConfig, PegasusForConditionalGeneration),
@@ -439,6 +450,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
+        (MBartConfig, MBartForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (LongformerConfig, LongformerForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
@@ -469,6 +481,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
        (AlbertConfig, AlbertForQuestionAnswering),
        (CamembertConfig, CamembertForQuestionAnswering),
        (BartConfig, BartForQuestionAnswering),
+        (MBartConfig, MBartForQuestionAnswering),
        (LongformerConfig, LongformerForQuestionAnswering),
        (XLMRobertaConfig, XLMRobertaForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
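
The effect of the new mapping entries can be sketched with the config-based auto constructors (randomly initialized
models, for illustration only):

.. code-block::

    from transformers import AutoModel, AutoModelForSeq2SeqLM, BlenderbotSmallConfig

    config = BlenderbotSmallConfig()                      # default config, random weights
    base = AutoModel.from_config(config)                  # resolves to BlenderbotSmallModel
    seq2seq = AutoModelForSeq2SeqLM.from_config(config)   # resolves to BlenderbotSmallForConditionalGeneration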
......
@@ -24,7 +24,7 @@ from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
-from ..blenderbot.tokenization_blenderbot import BlenderbotSmallTokenizer
+from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
from ..distilbert.tokenization_distilbert import DistilBertTokenizer
@@ -197,12 +197,12 @@ TOKENIZER_MAPPING = OrderedDict(
        (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
        (CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
        (PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
        (MBartConfig, (BarthezTokenizer, BarthezTokenizerFast)),
        (MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
        (XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
        (MarianConfig, (MarianTokenizer, None)),
        (BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
        (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
        (BartConfig, (BarthezTokenizer, BarthezTokenizerFast)),
        (BartConfig, (BartTokenizer, BartTokenizerFast)),
        (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
        (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
......
@@ -15,9 +15,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
-from .configuration_bart import BartConfig
+from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from .tokenization_bart import BartTokenizer
......
# coding=utf-8
-# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" BART configuration """
+""" BART model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -21,34 +21,33 @@ from ...utils import logging
logger = logging.get_logger(__name__)

BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/config.json",
    "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
-    "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/config.json",
-    "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json",
-    "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/config.json",
-    "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/config.json",
-    "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/config.json",
+    # See all BART models at https://huggingface.co/models?filter=bart
}

class BartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.BartModel`. It is used to
-    instantiate a BART model according to the specified arguments, defining the model architecture.
+    instantiate a BART model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the BART `facebook/bart-large
+    <https://huggingface.co/facebook/bart-large>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 50265):
-            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
-            :obj:`inputs_ids` passed when calling :class:`~transformers.BartModel`.
+            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
+            :obj:`inputs_ids` passed when calling :class:`~transformers.BartModel` or
+            :class:`~transformers.TFBartModel`.
        d_model (:obj:`int`, `optional`, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (:obj:`int`, `optional`, defaults to 12):
-            Number of encoder layers, 6 are used for the `bart-base` model.
+            Number of encoder layers.
        decoder_layers (:obj:`int`, `optional`, defaults to 12):
-            Number of decoder layers, 6 are used for the `bart-base` model.
+            Number of decoder layers.
        encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
@@ -73,145 +72,113 @@ class BartConfig(PretrainedConfig):
            just in case (e.g., 512 or 1024 or 2048).
        init_std (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        add_bias_logits (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            This should be completed, specific to marian.
-        normalize_before (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Call layernorm before attention ops.
-        normalize_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Call layernorm after embeddings.
-        static_position_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Don't learn positional embeddings, use sinusoidal.
-        add_final_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Why not add another layernorm?
-        do_blenderbot_90_layernorm (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Blenderbot-90m checkpoint uses `layernorm_embedding` one line earlier in the decoder.
-        scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Scale embeddings by diving by sqrt(d_model).
-        eos_token_id (:obj:`int`, `optional`, defaults to 2)
-            End of stream token id.
-        pad_token_id (:obj:`int`, `optional`, defaults to 1)
-            Padding token id.
-        bos_token_id (:obj:`int`, `optional`, defaults to 0)
-            Beginning of stream token id.
+        force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
+            :obj:`True` for `bart-large-cnn`.
        encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the encoder. See the `LayerDrop paper <see
            https://arxiv.org/abs/1909.11556>`__ for more details.
        decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the decoder. See the `LayerDrop paper <see
            https://arxiv.org/abs/1909.11556>`__ for more details.
-        extra_pos_embeddings: (:obj:`int`, `optional`, defaults to 2):
-            How many extra learned positional embeddings to use. Should be set to :obj:`pad_token_id+1`.
-        num_labels: (:obj:`int`, `optional`, defaults to 3):
-            The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
-        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether this is an encoder/decoder model.
-        force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
-            :obj:`True` for `bart-large-cnn`.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+        scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Scale embeddings by diving by sqrt(d_model).
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
+        num_labels: (:obj:`int`, `optional`, defaults to 3):
+            The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
+
+    Example::
+
+        >>> from transformers import BartModel, BartConfig
+        >>> # Initializing a BART facebook/bart-large style configuration
+        >>> configuration = BartConfig()
+        >>> # Initializing a model from the facebook/bart-large style configuration
+        >>> model = BartModel(configuration)
+        >>> # Accessing the model configuration
+        >>> configuration = model.config
    """
    model_type = "bart"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
-        activation_dropout=0.0,
-        extra_pos_embeddings=2,
-        activation_function="gelu",
        vocab_size=50265,
-        d_model=1024,
+        max_position_embeddings=1024,
-        encoder_ffn_dim=4096,
        encoder_layers=12,
+        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
-        decoder_ffn_dim=4096,
        decoder_layers=12,
+        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
-        attention_dropout=0.0,
+        activation_function="gelu",
+        d_model=1024,
        dropout=0.1,
-        max_position_embeddings=1024,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
-        num_labels=3,
-        is_encoder_decoder=True,
-        normalize_before=False,
-        add_final_layer_norm=False,
-        do_blenderbot_90_layernorm=False,
        scale_embedding=False,
-        normalize_embedding=True,
+        gradient_checkpointing=False,
-        static_position_embeddings=False,
-        add_bias_logits=False,
        force_bos_token_to_be_generated=False,
        use_cache=True,
+        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
-        **common_kwargs
+        is_encoder_decoder=True,
+        decoder_start_token_id=2,
+        **kwargs
    ):
-        r"""
-        :class:`~transformers.BartConfig` is the configuration class for `BartModel`.
-
-        Examples::
-
-            >>> from transformers import BartConfig, BartModel
-            >>> config = BartConfig.from_pretrained('facebook/bart-large')
-            >>> model = BartModel(config)
-        """
-        if "hidden_size" in common_kwargs:
-            raise ValueError("hidden size is called d_model")
        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
-            **common_kwargs,
+            decoder_start_token_id=decoder_start_token_id,
+            **kwargs,
        )

        self.vocab_size = vocab_size
-        self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
-        self.encoder_layers = self.num_hidden_layers = encoder_layers
+        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
-        self.encoder_layerdrop = encoder_layerdrop
-        self.decoder_layerdrop = decoder_layerdrop
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
-        self.max_position_embeddings = max_position_embeddings
+        self.dropout = dropout
-        self.init_std = init_std  # Normal(0, this parameter)
-        self.activation_function = activation_function
-        # Params introduced for Mbart
-        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-        self.normalize_embedding = normalize_embedding  # True for mbart, False otherwise
-        self.normalize_before = normalize_before  # combo of fairseq's encoder_ and decoder_normalize_before
-        self.add_final_layer_norm = add_final_layer_norm
-        # Params introduced for Marian
-        self.add_bias_logits = add_bias_logits
-        self.static_position_embeddings = static_position_embeddings
-        # 3 Types of Dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
-        self.dropout = dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
-        # Classifier stuff
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
-        # pos embedding offset
-        self.extra_pos_embeddings = extra_pos_embeddings
-        # bart has a hack that offsets positional embeddings by 2, other models don't do this
-        self.force_bos_token_to_be_generated = force_bos_token_to_be_generated
-        self.do_blenderbot_90_layernorm = do_blenderbot_90_layernorm
        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.gradient_checkpointing = gradient_checkpointing
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.force_bos_token_to_be_generated = force_bos_token_to_be_generated  # only relevant for CNN
+
+        # IMPORTANT
+        # DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
+        self.extra_pos_embeddings = 2
+        self.normalize_before = False
+        self.add_final_layer_norm = False
+        self.do_blenderbot_90_layernorm = False
+        self.normalize_embedding = True
+        self.static_position_embeddings = False
+        self.add_bias_logits = False

    @property
    def num_attention_heads(self) -> int:
@@ -220,11 +187,3 @@ class BartConfig(PretrainedConfig):
    @property
    def hidden_size(self) -> int:
        return self.d_model
-
-    def is_valid_mbart(self) -> bool:
-        """Is the configuration aligned with the MBART paper."""
-        if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
-            return True
-        if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
-            logger.info("This configuration is a mixture of MBART and BART settings")
-        return False
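
For reference, a quick sketch of what the kept backward-compatibility properties and the new ``__init__`` imply for a
default configuration (based on the code above; illustrative only):

.. code-block::

    from transformers import BartConfig

    config = BartConfig()  # facebook/bart-large-like defaults
    assert config.hidden_size == config.d_model
    assert config.num_attention_heads == config.encoder_attention_heads
    assert config.num_hidden_layers == config.encoder_layers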
@@ -545,7 +545,7 @@ BART_INPUTS_DOCSTRING = r"""
        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Provide for translation and summarization training. By default, the model will create this tensor by
            shifting the input_ids right, following the paper.
-        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
+        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
        encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
......
@@ -18,7 +18,7 @@
from ...file_utils import is_tf_available, is_torch_available
from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
-from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokenizer
+from .tokenization_blenderbot import BlenderbotTokenizer

if is_torch_available():
@@ -26,7 +26,9 @@ if is_torch_available():
        BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
        BlenderbotForConditionalGeneration,
        BlenderbotModel,
+        BlenderbotPreTrainedModel,
    )

if is_tf_available():
    from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
#!/usr/bin/env python3
# coding=utf-8 # coding=utf-8
# Copyright (c) Facebook, Inc. and Huggingface, 2020 # Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved.
# #
# This source code is licensed under the MIT license found in the; # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
...@@ -13,46 +12,49 @@ ...@@ -13,46 +12,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# LICENSE file in the root directory of this source tree. """ Blenderbot model configuration """
"""
BlenderbotConfig has the same signature as BartConfig. We only rewrite the signature in order to document
blenderbot-90M defaults.
"""
from ..bart.configuration_bart import BartConfig
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP = { BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/blenderbot-3B": "https://cdn.huggingface.co/facebook/blenderbot-3B/config.json", "facebook/blenderbot-3B": "https://huggingface.co/facebook/blenderbot-3B/resolve/main/config.json",
"facebook/blenderbot-90M": "https://cdn.huggingface.co/facebook/blenderbot-90M/config.json", # See all Blenderbot models at https://huggingface.co/models?filter=blenderbot
} }
class BlenderbotConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.BlenderbotModel`. It is used
    to instantiate a Blenderbot model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Blenderbot
    `facebook/blenderbot-3B <https://huggingface.co/facebook/blenderbot-3B>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 8008):
            Vocabulary size of the Blenderbot model. Defines the number of different tokens that can be represented by
            the :obj:`inputs_ids` passed when calling :class:`~transformers.BlenderbotModel` or
            :class:`~transformers.TFBlenderbotModel`.
        d_model (:obj:`int`, `optional`, defaults to 2560):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (:obj:`int`, `optional`, defaults to 2):
            Number of encoder layers.
        decoder_layers (:obj:`int`, `optional`, defaults to 24):
            Number of decoder layers.
        encoder_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (:obj:`int`, `optional`, defaults to 10240):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (:obj:`int`, `optional`, defaults to 10240):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 128):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the encoder. See the `LayerDrop paper
            <https://arxiv.org/abs/1909.11556>`__ for more details.
        decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the decoder. See the `LayerDrop paper
            <https://arxiv.org/abs/1909.11556>`__ for more details.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
        scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).

    Example::

        >>> from transformers import BlenderbotModel, BlenderbotConfig

        >>> # Initializing a Blenderbot facebook/blenderbot-3B style configuration
        >>> configuration = BlenderbotConfig()

        >>> # Initializing a model from the facebook/blenderbot-3B style configuration
        >>> model = BlenderbotModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "blenderbot"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=8008,
        max_position_embeddings=128,
        encoder_layers=2,
        encoder_ffn_dim=10240,
        encoder_attention_heads=32,
        decoder_layers=24,
        decoder_ffn_dim=10240,
        decoder_attention_heads=32,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=2560,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=1,
        classifier_dropout=0.0,
        scale_embedding=False,
        gradient_checkpointing=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # IMPORTANT
        # DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
        self.extra_pos_embeddings = 0
        self.normalize_before = True
        self.add_final_layer_norm = True
        self.do_blenderbot_90_layernorm = True
        self.normalize_embedding = False
        self.static_position_embeddings = False
        self.add_bias_logits = False
        self.force_bos_token_to_be_generated = False

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads

    @property
    def hidden_size(self) -> int:
        return self.d_model
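
A minimal usage sketch for the standalone configuration defined above, assuming a transformers installation that already includes this commit. The default values come from the class itself (they mirror facebook/blenderbot-3B), while the tiny hyperparameters in the second configuration are illustrative only and do not correspond to any released checkpoint; BlenderbotForConditionalGeneration is the generation model this PR moves into modeling_blenderbot.py.

from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration

# Default arguments mirror facebook/blenderbot-3B: d_model=2560, 2 encoder / 24 decoder layers.
config = BlenderbotConfig()
assert config.hidden_size == config.d_model == 2560  # `hidden_size` is the read-only alias defined above
assert config.num_attention_heads == config.encoder_attention_heads == 32

# A deliberately tiny configuration (values chosen only for this example) so that a
# randomly initialized model fits comfortably in memory, e.g. for a quick smoke test.
tiny_config = BlenderbotConfig(
    vocab_size=128,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = BlenderbotForConditionalGeneration(tiny_config)  # randomly initialized weights
print(model.config.model_type)  # "blenderbot"

The read-only num_attention_heads and hidden_size properties keep the BartConfig-style attribute names working on the new PretrainedConfig-based class, which is why the assertions above hold without any extra wiring.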