Unverified Commit 86578bb0 authored by Patrick von Platen, committed by GitHub

[AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933)

* first commit

* add new auto models

* better naming

* fix bert automodel

* fix automodel for pretraining

* add models to init

* fix name typo

* fix typo

* better naming

* future warning instead of deprecation warning
parent 56200331
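
In practice the split means choosing the auto class that matches the checkpoint's architecture instead of the catch-all `AutoModelWithLMHead` (which keeps working but now emits a `FutureWarning`). A minimal usage sketch; the checkpoint names below are only illustrative:

from transformers import (
    AutoModelForCausalLM,   # decoder-only / causal LM checkpoints, e.g. GPT-2
    AutoModelForMaskedLM,   # encoder-only / masked LM checkpoints, e.g. BERT
    AutoModelForSeq2SeqLM,  # encoder-decoder checkpoints, e.g. T5
)

causal_lm = AutoModelForCausalLM.from_pretrained("gpt2")
masked_lm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
seq2seq_lm = AutoModelForSeq2SeqLM.from_pretrained("t5-small")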
@@ -166,11 +166,17 @@ if is_torch_available():
AutoModelForSequenceClassification,
AutoModelForQuestionAnswering,
AutoModelWithLMHead,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoModelForTokenClassification,
AutoModelForMultipleChoice,
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
@@ -182,6 +188,7 @@ if is_torch_available():
BertModel,
BertForPreTraining,
BertForMaskedLM,
BertLMHeadModel,
BertForNextSentencePrediction,
BertForSequenceClassification,
BertForMultipleChoice,
This diff is collapsed.
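
The collapsed diff above is most likely the main `modeling_auto.py` change, where the three new classes and their `MODEL_FOR_*_MAPPING` tables live. For orientation only, here is a simplified, hypothetical sketch of the config-class-to-model-class dispatch these auto classes are built on; the real implementation covers many more architectures and edge cases:

from collections import OrderedDict

from transformers import BertConfig, BertForMaskedLM, GPT2Config, GPT2LMHeadModel

# Hypothetical stand-ins for the MODEL_FOR_*_MAPPING tables in modeling_auto.py.
SIMPLE_CAUSAL_LM_MAPPING = OrderedDict([(GPT2Config, GPT2LMHeadModel)])
SIMPLE_MASKED_LM_MAPPING = OrderedDict([(BertConfig, BertForMaskedLM)])

def resolve_model_class(config, mapping):
    # Dispatch on the configuration class, the same idea AutoModelFor* relies on.
    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class
    raise ValueError(f"No model class registered for {config.__class__.__name__}.")

assert resolve_model_class(GPT2Config(), SIMPLE_CAUSAL_LM_MAPPING) is GPT2LMHeadModel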
@@ -987,6 +987,9 @@ class BertLMHeadModel(BertPreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert (
not config.is_decoder
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
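
The added assertion makes the intent explicit: `BertForMaskedLM` is for bi-directional masked-LM use only, while left-to-right (causal) generation goes through the new `BertLMHeadModel`, whose config is expected to have `is_decoder=True`. A small sketch with randomly initialized weights, for illustration only:

from transformers import BertConfig, BertForMaskedLM, BertLMHeadModel

# Masked LM: keep the default is_decoder=False for bi-directional self-attention.
mlm_config = BertConfig()
masked_lm = BertForMaskedLM(mlm_config)

# Causal LM head: the config has to be marked as a decoder.
clm_config = BertConfig(is_decoder=True)
causal_lm = BertLMHeadModel(clm_config)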
@@ -32,7 +32,7 @@ class EncoderDecoderModel(PreTrainedModel):
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
@@ -61,9 +61,9 @@ class EncoderDecoderModel(PreTrainedModel):
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
from transformers import AutoModelWithLMHead
from transformers import AutoModelForCausalLM
decoder = AutoModelWithLMHead.from_config(config.decoder)
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
@@ -157,7 +157,7 @@ class EncoderDecoderModel(PreTrainedModel):
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from .modeling_auto import AutoModelWithLMHead
from .modeling_auto import AutoModelForCausalLM
if "config" not in kwargs_decoder:
from transformers import AutoConfig
@@ -176,7 +176,7 @@ class EncoderDecoderModel(PreTrainedModel):
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
return cls(encoder=encoder, decoder=decoder)
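
Downstream, the change is invisible when building a model with `from_encoder_decoder_pretrained`; the decoder checkpoint is simply loaded through `AutoModelForCausalLM` now. A minimal sketch, with an illustrative checkpoint name:

from transformers import EncoderDecoderModel

# Encoder is loaded via AutoModel, decoder via the new AutoModelForCausalLM.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",  # encoder
    "bert-base-uncased",  # decoder
)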
@@ -26,13 +26,20 @@ if is_torch_available():
from transformers import (
AutoConfig,
BertConfig,
GPT2Config,
T5Config,
AutoModel,
BertModel,
AutoModelForPreTraining,
BertForPreTraining,
AutoModelForCausalLM,
GPT2LMHeadModel,
AutoModelWithLMHead,
AutoModelForMaskedLM,
BertForMaskedLM,
RobertaForMaskedLM,
AutoModelForSeq2SeqLM,
T5ForConditionalGeneration,
AutoModelForSequenceClassification,
BertForSequenceClassification,
AutoModelForQuestionAnswering,
@@ -41,6 +48,8 @@ if is_torch_available():
BertForTokenClassification,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_auto import (
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
@@ -48,6 +57,9 @@ if is_torch_available():
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
)
@@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForMaskedLM)
@slow
def test_model_for_causal_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, GPT2Config)
model = AutoModelForCausalLM.from_pretrained(model_name)
model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, GPT2LMHeadModel)
@slow
def test_model_for_masked_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForMaskedLM)
@slow
def test_model_for_encoder_decoder_lm(self):
logging.basicConfig(level=logging.INFO)
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, T5Config)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, T5ForConditionalGeneration)
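
All three new tests follow the same pattern: resolve the config, load the model (once more with `output_loading_info=True`), and check the concrete class. They are marked `@slow`, so in the transformers test suite they only run when the `RUN_SLOW` environment variable is set. When reproducing the pattern outside the tests, note that `output_loading_info=True` returns the model together with a dict of loading diagnostics; a small sketch, assuming a BERT checkpoint:

from transformers import AutoModelForMaskedLM

model, loading_info = AutoModelForMaskedLM.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)
# loading_info is a plain dict of diagnostics, with keys such as
# "missing_keys" and "unexpected_keys".
print(sorted(loading_info.keys()))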
@slow
def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
@@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase):
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
)
for mapping in mappings:
@@ -27,6 +27,7 @@ if is_torch_available():
from transformers import (
BertConfig,
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
@@ -35,7 +36,7 @@ if is_torch_available():
BertForTokenClassification,
BertForMultipleChoice,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
class BertModelTester:
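
Because `BertLMHeadModel` is now part of the public `transformers` namespace (see the `__init__.py` hunk above), the test no longer needs the module-level import; user code can rely on the top-level import as well:

from transformers import BertLMHeadModel  # previously imported from transformers.modeling_bert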