Unverified Commit 86578bb0 authored by Patrick von Platen, committed by GitHub

[AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933)

* first commit

* add new auto models

* better naming

* fix bert automodel

* fix automodel for pretraining

* add models to init

* fix name typo

* fix typo

* better naming

* future warning instead of deprecation warning
parent 56200331
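`AutoModelWithLMHead` previously covered causal, masked, and sequence-to-sequence language modeling in a single class; this commit splits it into three dedicated auto classes and keeps the old class with a `FutureWarning`. A minimal usage sketch of the new entry points (checkpoint names are illustrative, not part of this commit):

```python
from transformers import (
    AutoModelForCausalLM,    # decoder-only LMs, e.g. GPT-2
    AutoModelForMaskedLM,    # encoder-only LMs, e.g. BERT
    AutoModelForSeq2SeqLM,   # encoder-decoder LMs, e.g. T5
)

causal = AutoModelForCausalLM.from_pretrained("gpt2")
masked = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
seq2seq = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# The legacy class still loads models but now emits a FutureWarning:
# from transformers import AutoModelWithLMHead
# legacy = AutoModelWithLMHead.from_pretrained("gpt2")
```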
@@ -166,11 +166,17 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         AutoModelForQuestionAnswering,
         AutoModelWithLMHead,
+        AutoModelForCausalLM,
+        AutoModelForMaskedLM,
+        AutoModelForSeq2SeqLM,
         AutoModelForTokenClassification,
         AutoModelForMultipleChoice,
         MODEL_MAPPING,
         MODEL_FOR_PRETRAINING_MAPPING,
         MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
@@ -182,6 +188,7 @@ if is_torch_available():
         BertModel,
         BertForPreTraining,
         BertForMaskedLM,
+        BertLMHeadModel,
         BertForNextSentencePrediction,
         BertForSequenceClassification,
         BertForMultipleChoice,
......
@@ -987,6 +987,9 @@ class BertLMHeadModel(BertPreTrainedModel):
 class BertForMaskedLM(BertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
+        assert (
+            not config.is_decoder
+        ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
         self.bert = BertModel(config)
         self.cls = BertOnlyMLMHead(config)
......
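The added assertion makes the split explicit on the BERT side: `BertForMaskedLM` now rejects decoder configs, while causal generation is handled by `BertLMHeadModel`. A minimal sketch of the two configurations (assuming a transformers build that includes this commit; models are randomly initialized from config):

```python
from transformers import BertConfig, BertForMaskedLM, BertLMHeadModel

# Masked LM: bi-directional self-attention, is_decoder must stay False (the default).
mlm = BertForMaskedLM(BertConfig())

# Causal LM: left-to-right attention, the config is marked as a decoder.
clm = BertLMHeadModel(BertConfig(is_decoder=True))
```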
@@ -32,7 +32,7 @@ class EncoderDecoderModel(PreTrainedModel):
     instantiated as a transformer architecture with one of the base model
     classes of the library as encoder and another one as
     decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
-    class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
+    class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
     """
     config_class = EncoderDecoderConfig
     base_model_prefix = "encoder_decoder"
@@ -61,9 +61,9 @@ class EncoderDecoderModel(PreTrainedModel):
             encoder = AutoModel.from_config(config.encoder)

         if decoder is None:
-            from transformers import AutoModelWithLMHead
+            from transformers import AutoModelForCausalLM

-            decoder = AutoModelWithLMHead.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder)

         self.encoder = encoder
         self.decoder = decoder
@@ -157,7 +157,7 @@ class EncoderDecoderModel(PreTrainedModel):
            assert (
                decoder_pretrained_model_name_or_path is not None
            ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
-           from .modeling_auto import AutoModelWithLMHead
+           from .modeling_auto import AutoModelForCausalLM

            if "config" not in kwargs_decoder:
                from transformers import AutoConfig
@@ -176,7 +176,7 @@ class EncoderDecoderModel(PreTrainedModel):
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

-           decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+           decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        return cls(encoder=encoder, decoder=decoder)
......
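With the decoder now loaded through `AutoModelForCausalLM`, a typical way to build an encoder-decoder pair from pretrained checkpoints looks like the sketch below (checkpoint names are illustrative; the decoder config is expected to have `is_decoder=True`, which is set automatically when no explicit decoder config is passed):

```python
from transformers import EncoderDecoderModel

# Encoder is loaded via AutoModel, decoder via AutoModelForCausalLM.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",  # encoder checkpoint
    "bert-base-uncased",  # decoder checkpoint
)
```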
@@ -26,13 +26,20 @@ if is_torch_available():
     from transformers import (
         AutoConfig,
         BertConfig,
+        GPT2Config,
+        T5Config,
         AutoModel,
         BertModel,
         AutoModelForPreTraining,
         BertForPreTraining,
+        AutoModelForCausalLM,
+        GPT2LMHeadModel,
         AutoModelWithLMHead,
+        AutoModelForMaskedLM,
         BertForMaskedLM,
         RobertaForMaskedLM,
+        AutoModelForSeq2SeqLM,
+        T5ForConditionalGeneration,
         AutoModelForSequenceClassification,
         BertForSequenceClassification,
         AutoModelForQuestionAnswering,
@@ -41,6 +48,8 @@ if is_torch_available():
         BertForTokenClassification,
     )
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
+    from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
+    from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
     from transformers.modeling_auto import (
         MODEL_MAPPING,
         MODEL_FOR_PRETRAINING_MAPPING,
@@ -48,6 +57,9 @@ if is_torch_available():
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     )
@@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, BertForMaskedLM)

+    @slow
+    def test_model_for_causal_lm(self):
+        logging.basicConfig(level=logging.INFO)
+        for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            config = AutoConfig.from_pretrained(model_name)
+            self.assertIsNotNone(config)
+            self.assertIsInstance(config, GPT2Config)
+
+            model = AutoModelForCausalLM.from_pretrained(model_name)
+            model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
+            self.assertIsNotNone(model)
+            self.assertIsInstance(model, GPT2LMHeadModel)
+
+    @slow
+    def test_model_for_masked_lm(self):
+        logging.basicConfig(level=logging.INFO)
+        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            config = AutoConfig.from_pretrained(model_name)
+            self.assertIsNotNone(config)
+            self.assertIsInstance(config, BertConfig)
+
+            model = AutoModelForMaskedLM.from_pretrained(model_name)
+            model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
+            self.assertIsNotNone(model)
+            self.assertIsInstance(model, BertForMaskedLM)
+
+    @slow
+    def test_model_for_encoder_decoder_lm(self):
+        logging.basicConfig(level=logging.INFO)
+        for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            config = AutoConfig.from_pretrained(model_name)
+            self.assertIsNotNone(config)
+            self.assertIsInstance(config, T5Config)
+
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+            model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
+            self.assertIsNotNone(model)
+            self.assertIsInstance(model, T5ForConditionalGeneration)
+
     @slow
     def test_sequence_classification_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
@@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase):
            MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
            MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
            MODEL_WITH_LM_HEAD_MAPPING,
+           MODEL_FOR_CAUSAL_LM_MAPPING,
+           MODEL_FOR_MASKED_LM_MAPPING,
+           MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        )
        for mapping in mappings:
......
@@ -27,6 +27,7 @@ if is_torch_available():
     from transformers import (
         BertConfig,
         BertModel,
+        BertLMHeadModel,
         BertForMaskedLM,
         BertForNextSentencePrediction,
         BertForPreTraining,
@@ -35,7 +36,7 @@ if is_torch_available():
         BertForTokenClassification,
         BertForMultipleChoice,
     )
-    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
+    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST


 class BertModelTester:
......