Unverified Commit 84be482f authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

AutoTokenizer supports mbart-large-en-ro (#5121)

parent 2db1e2f4
...@@ -19,7 +19,7 @@ import logging ...@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
...@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict( ...@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict(
("camembert", CamembertConfig,), ("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,), ("xlm-roberta", XLMRobertaConfig,),
("marian", MarianConfig,), ("marian", MarianConfig,),
("mbart", MBartConfig,),
("bart", BartConfig,), ("bart", BartConfig,),
("reformer", ReformerConfig,), ("reformer", ReformerConfig,),
("longformer", LongformerConfig,), ("longformer", LongformerConfig,),
......
...@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig): ...@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig):
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding: if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
logger.info("This configuration is a mixture of MBART and BART settings") logger.info("This configuration is a mixture of MBART and BART settings")
return False return False
class MBartConfig(BartConfig):
    """Configuration class for mBART models.

    Behaves exactly like :class:`BartConfig`; it exists only so that mBART
    checkpoints get their own ``model_type`` and can therefore be routed to
    :class:`MBartTokenizer` (rather than ``BartTokenizer``) by the Auto
    classes' pattern matching.
    """

    # "mbart" is registered in CONFIG_MAPPING ahead of "bart" so that
    # checkpoint names containing "mbart" match this entry first.
    model_type = "mbart"
...@@ -30,6 +30,7 @@ from .configuration_auto import ( ...@@ -30,6 +30,7 @@ from .configuration_auto import (
FlaubertConfig, FlaubertConfig,
GPT2Config, GPT2Config,
LongformerConfig, LongformerConfig,
MBartConfig,
OpenAIGPTConfig, OpenAIGPTConfig,
ReformerConfig, ReformerConfig,
RetriBertConfig, RetriBertConfig,
...@@ -43,7 +44,7 @@ from .configuration_auto import ( ...@@ -43,7 +44,7 @@ from .configuration_auto import (
from .configuration_marian import MarianConfig from .configuration_marian import MarianConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer from .tokenization_bart import BartTokenizer, MBartTokenizer
from .tokenization_bert import BertTokenizer, BertTokenizerFast from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_camembert import CamembertTokenizer from .tokenization_camembert import CamembertTokenizer
...@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict( ...@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)), (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)), (AlbertConfig, (AlbertTokenizer, None)),
(CamembertConfig, (CamembertTokenizer, None)), (CamembertConfig, (CamembertTokenizer, None)),
(MBartConfig, (MBartTokenizer, None)),
(XLMRobertaConfig, (XLMRobertaTokenizer, None)), (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(MarianConfig, (MarianTokenizer, None)), (MarianConfig, (MarianTokenizer, None)),
(BartConfig, (BartTokenizer, None)), (BartConfig, (BartTokenizer, None)),
......
...@@ -31,6 +31,7 @@ if is_torch_available(): ...@@ -31,6 +31,7 @@ if is_torch_available():
from transformers import ( from transformers import (
AutoModel, AutoModel,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
AutoTokenizer, AutoTokenizer,
BartModel, BartModel,
BartForConditionalGeneration, BartForConditionalGeneration,
...@@ -38,7 +39,6 @@ if is_torch_available(): ...@@ -38,7 +39,6 @@ if is_torch_available():
BartForQuestionAnswering, BartForQuestionAnswering,
BartConfig, BartConfig,
BartTokenizer, BartTokenizer,
MBartTokenizer,
BatchEncoding, BatchEncoding,
pipeline, pipeline,
) )
...@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase): ...@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
checkpoint_name = "facebook/mbart-large-en-ro" checkpoint_name = "facebook/mbart-large-en-ro"
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name) cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1 cls.pad_token_id = 1
return cls return cls
@cached_property @cached_property
def model(self): def model(self):
"""Only load the model if needed.""" """Only load the model if needed."""
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
if "cuda" in torch_device: if "cuda" in torch_device:
model = model.half() model = model.half()
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment