"tests/vscode:/vscode.git/clone" did not exist on "681183974acb987d4bb7037de3e82078bd0308a1"
Unverified Commit 7f286132 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TFBart] Split TF-Bart (#9497)

* make templates ready

* make add_new_model_command_ready

* finish tf bart

* prepare tf mbart

* finish tf bart

* add tf mbart

* add marian

* prep pegasus

* add tf pegasus

* push blenderbot tf

* add blenderbot

* add blenderbot small

* clean-up

* make fix copy

* define blend bot tok

* fix

* up

* make style

* add to docs

* add copy statements

* overwrite changes

* improve

* fix docs

* finish

* fix last slow test

* fix missing git conflict line

* fix blenderbot

* up

* fix blenderbot small

* load changes

* finish copied from

* upload fix
parent 0ecbb698
...@@ -225,7 +225,7 @@ TensorFlow and/or Flax. ...@@ -225,7 +225,7 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ | | Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall | ✅ | ❌ | ✅ | | ❌ | | BlenderbotSmall | ✅ | ❌ | ✅ | | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -98,10 +98,15 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` ...@@ -98,10 +98,15 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
:members: forward :members: forward
TFBlenderbotForConditionalGeneration TFBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate` .. autoclass:: transformers.TFBlenderbotModel
:members: call
TFBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration .. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members: :members: call
...@@ -68,3 +68,17 @@ BlenderbotSmallForConditionalGeneration ...@@ -68,3 +68,17 @@ BlenderbotSmallForConditionalGeneration
.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration .. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
:members: forward :members: forward
TFBlenderbotSmallModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFBlenderbotSmallModel
:members: call
TFBlenderbotSmallForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFBlenderbotSmallForConditionalGeneration
:members: call
...@@ -193,7 +193,15 @@ MarianMTModel ...@@ -193,7 +193,15 @@ MarianMTModel
:members: forward :members: forward
TFMarianModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMarianModel
:members: call
TFMarianMTModel TFMarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMarianMTModel .. autoclass:: transformers.TFMarianMTModel
:members: call
...@@ -124,8 +124,15 @@ MBartForSequenceClassification ...@@ -124,8 +124,15 @@ MBartForSequenceClassification
.. autoclass:: transformers.MBartForSequenceClassification .. autoclass:: transformers.MBartForSequenceClassification
TFMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMBartModel
:members: call
TFMBartForConditionalGeneration TFMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMBartForConditionalGeneration .. autoclass:: transformers.TFMBartForConditionalGeneration
:members: :members: call
...@@ -131,7 +131,15 @@ PegasusForConditionalGeneration ...@@ -131,7 +131,15 @@ PegasusForConditionalGeneration
:members: forward :members: forward
TFPegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFPegasusModel
:members: call
TFPegasusForConditionalGeneration TFPegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFPegasusForConditionalGeneration .. autoclass:: transformers.TFPegasusForConditionalGeneration
:members: call
...@@ -868,7 +868,10 @@ if is_tf_available(): ...@@ -868,7 +868,10 @@ if is_tf_available():
"TFBertPreTrainedModel", "TFBertPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot"].append("TFBlenderbotForConditionalGeneration") _import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
)
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
[ [
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -986,8 +989,8 @@ if is_tf_available(): ...@@ -986,8 +989,8 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder", "TFLxmertVisualFeatureEncoder",
] ]
) )
_import_structure["models.marian"].append("TFMarianMTModel") _import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
_import_structure["models.mbart"].append("TFMBartForConditionalGeneration") _import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.mobilebert"].extend( _import_structure["models.mobilebert"].extend(
[ [
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -1028,7 +1031,7 @@ if is_tf_available(): ...@@ -1028,7 +1031,7 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel", "TFOpenAIGPTPreTrainedModel",
] ]
) )
_import_structure["models.pegasus"].append("TFPegasusForConditionalGeneration") _import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -1855,7 +1858,8 @@ if TYPE_CHECKING: ...@@ -1855,7 +1858,8 @@ if TYPE_CHECKING:
TFBertModel, TFBertModel,
TFBertPreTrainedModel, TFBertPreTrainedModel,
) )
from .models.blenderbot import TFBlenderbotForConditionalGeneration from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.camembert import ( from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
...@@ -1953,8 +1957,8 @@ if TYPE_CHECKING: ...@@ -1953,8 +1957,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel, TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder, TFLxmertVisualFeatureEncoder,
) )
from .models.marian import TFMarianMTModel from .models.marian import TFMarian, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.mobilebert import ( from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM, TFMobileBertForMaskedLM,
...@@ -1989,7 +1993,7 @@ if TYPE_CHECKING: ...@@ -1989,7 +1993,7 @@ if TYPE_CHECKING:
TFOpenAIGPTModel, TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel, TFOpenAIGPTPreTrainedModel,
) )
from .models.pegasus import TFPegasusForConditionalGeneration from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.roberta import ( from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
......
...@@ -44,7 +44,11 @@ from ..bert.modeling_tf_bert import ( ...@@ -44,7 +44,11 @@ from ..bert.modeling_tf_bert import (
TFBertLMHeadModel, TFBertLMHeadModel,
TFBertModel, TFBertModel,
) )
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import ( from ..camembert.modeling_tf_camembert import (
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
TFCamembertForMultipleChoice, TFCamembertForMultipleChoice,
...@@ -100,8 +104,8 @@ from ..longformer.modeling_tf_longformer import ( ...@@ -100,8 +104,8 @@ from ..longformer.modeling_tf_longformer import (
TFLongformerModel, TFLongformerModel,
) )
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianMTModel from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import ( from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM, TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice, TFMobileBertForMultipleChoice,
...@@ -122,7 +126,7 @@ from ..mpnet.modeling_tf_mpnet import ( ...@@ -122,7 +126,7 @@ from ..mpnet.modeling_tf_mpnet import (
) )
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..roberta.modeling_tf_roberta import ( from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
TFRobertaForMultipleChoice, TFRobertaForMultipleChoice,
...@@ -167,6 +171,7 @@ from .configuration_auto import ( ...@@ -167,6 +171,7 @@ from .configuration_auto import (
BartConfig, BartConfig,
BertConfig, BertConfig,
BlenderbotConfig, BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig, CamembertConfig,
CTRLConfig, CTRLConfig,
DistilBertConfig, DistilBertConfig,
...@@ -225,6 +230,12 @@ TF_MODEL_MAPPING = OrderedDict( ...@@ -225,6 +230,12 @@ TF_MODEL_MAPPING = OrderedDict(
(FunnelConfig, TFFunnelModel), (FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder), (DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel), (MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
(MBartConfig, TFMBartModel),
(MarianConfig, TFMarianModel),
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
] ]
) )
...@@ -328,6 +339,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( ...@@ -328,6 +339,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
(MBartConfig, TFMBartForConditionalGeneration), (MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration), (PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration), (BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration),
] ]
) )
......
...@@ -24,6 +24,7 @@ from ..bart.tokenization_bart import BartTokenizer ...@@ -24,6 +24,7 @@ from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer from ..deberta.tokenization_deberta import DebertaTokenizer
...@@ -58,6 +59,7 @@ from .configuration_auto import ( ...@@ -58,6 +59,7 @@ from .configuration_auto import (
BertConfig, BertConfig,
BertGenerationConfig, BertGenerationConfig,
BlenderbotConfig, BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig, CamembertConfig,
CTRLConfig, CTRLConfig,
DebertaConfig, DebertaConfig,
...@@ -201,7 +203,8 @@ TOKENIZER_MAPPING = OrderedDict( ...@@ -201,7 +203,8 @@ TOKENIZER_MAPPING = OrderedDict(
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)), (MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)), (XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)), (MarianConfig, (MarianTokenizer, None)),
(BlenderbotConfig, (BlenderbotSmallTokenizer, None)), (BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)), (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(BartConfig, (BartTokenizer, BartTokenizerFast)), (BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)), (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
......
...@@ -170,16 +170,6 @@ class BartConfig(PretrainedConfig): ...@@ -170,16 +170,6 @@ class BartConfig(PretrainedConfig):
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 2
self.normalize_before = False
self.add_final_layer_norm = False
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = True
self.static_position_embeddings = False
self.add_bias_logits = False
@property @property
def num_attention_heads(self) -> int: def num_attention_heads(self) -> int:
return self.encoder_attention_heads return self.encoder_attention_heads
......
...@@ -36,7 +36,7 @@ if is_torch_available(): ...@@ -36,7 +36,7 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration"] _import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -52,7 +52,7 @@ if TYPE_CHECKING: ...@@ -52,7 +52,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
else: else:
import importlib import importlib
......
...@@ -161,17 +161,6 @@ class BlenderbotConfig(PretrainedConfig): ...@@ -161,17 +161,6 @@ class BlenderbotConfig(PretrainedConfig):
self.gradient_checkpointing = gradient_checkpointing self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 0
self.normalize_before = True
self.add_final_layer_norm = True
self.do_blenderbot_90_layernorm = True
self.normalize_embedding = False
self.static_position_embeddings = False
self.add_bias_logits = False
self.force_bos_token_to_be_generated = False
@property @property
def num_attention_heads(self) -> int: def num_attention_heads(self) -> int:
return self.encoder_attention_heads return self.encoder_attention_heads
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_torch_available from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -33,6 +33,11 @@ if is_torch_available(): ...@@ -33,6 +33,11 @@ if is_torch_available():
"BlenderbotSmallPreTrainedModel", "BlenderbotSmallPreTrainedModel",
] ]
if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
...@@ -46,6 +51,9 @@ if TYPE_CHECKING: ...@@ -46,6 +51,9 @@ if TYPE_CHECKING:
BlenderbotSmallPreTrainedModel, BlenderbotSmallPreTrainedModel,
) )
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
else: else:
import importlib import importlib
import os import os
......
...@@ -866,6 +866,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): ...@@ -866,6 +866,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
......
...@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__) ...@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json", "vocab_file": "vocab.json",
"merges_file": "merges.txt", "merges_file": "merges.txt",
# "tokenizer_config_file": "tokenizer_config.json", "tokenizer_config_file": "tokenizer_config.json",
} }
...@@ -75,13 +75,20 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer): ...@@ -75,13 +75,20 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer` Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
""" """
vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} vocab_files_names = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
"tokenizer_config": "tokenizer_config.json",
}
pretrained_vocab_files_map = { pretrained_vocab_files_map = {
"vocab_file": { "vocab_file": {
"facebook/blenderbot_small-90M": "https://cdn.huggingface.co/facebook/blenderbot_small-90M/vocab.json" "facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/vocab.json"
}, },
"merges_file": { "merges_file": {
"facebook/blenderbot_small-90M": "https://cdn.huggingface.co/facebook/blenderbot_small-90M/merges.txt" "facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/merges.txt"
},
"tokenizer_config_file": {
"facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/tokenizer.json"
}, },
} }
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512} max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
......
...@@ -1475,7 +1475,7 @@ LED_INPUTS_DOCSTRING = r""" ...@@ -1475,7 +1475,7 @@ LED_INPUTS_DOCSTRING = r"""
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail. more detail.
return_dict (:obj:`bool`, `optional`): return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple. Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
training (:obj:`bool`, `optional`, defaults to :obj:`False`): training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation). behaviors between training and evaluation).
......
...@@ -42,7 +42,7 @@ if is_torch_available(): ...@@ -42,7 +42,7 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianMTModel"] _import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -60,7 +60,7 @@ if TYPE_CHECKING: ...@@ -60,7 +60,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_marian import TFMarianMTModel from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
else: else:
import importlib import importlib
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment