Commit b7141a1b authored by thomwolf

maxi simplification

parent bfbe68f0
@@ -21,14 +21,7 @@ import logging
 import torch
 from torch import nn
-from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
-from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
-from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
-from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
-from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
-from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
-from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
-from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
+from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -43,22 +36,6 @@ class PreTrainedSeq2seq(nn.Module):
     that will be instantiated as a Seq2seq model with one of the base model classes of the library
     as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
     class method.
-    The `from_pretrained()` method takes care of returning the correct model class instance
-    using pattern matching on the `pretrained_model_name_or_path` string.
-    The base model class to instantiate is selected as the first pattern matching
-    in the `pretrained_model_name_or_path` string (in the following order):
-        - contains `distilbert`: DistilBertModel (DistilBERT model)
-        - contains `roberta`: RobertaModel (RoBERTa model)
-        - contains `bert`: BertModel (Bert model)
-        - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
-        - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
-        - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
-        - contains `xlnet`: XLNetModel (XLNet model)
-        - contains `xlm`: XLMModel (XLM model)
-    This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
@@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module):
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r""" Instantiates one of the base model classes of the library
         from a pre-trained model configuration.
-        The model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `distilbert`: DistilBertModel (DistilBERT model)
-            - contains `roberta`: RobertaModel (RoBERTa model)
-            - contains `bert`: BertModel (Bert model)
-            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
-            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
-            - contains `xlnet`: XLNetModel (XLNet model)
-            - contains `xlm`: XLMModel (XLM model)
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
         To train the model, you should first set it back in training mode with `model.train()`
@@ -155,26 +120,7 @@ class PreTrainedSeq2seq(nn.Module):
         else:
             # Load and initialize the encoder
             kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            if 'distilbert' in pretrained_model_name_or_path:
-                encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'roberta' in pretrained_model_name_or_path:
-                encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'bert' in pretrained_model_name_or_path:
-                encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'openai-gpt' in pretrained_model_name_or_path:
-                encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'gpt2' in pretrained_model_name_or_path:
-                encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'transfo-xl' in pretrained_model_name_or_path:
-                encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'xlnet' in pretrained_model_name_or_path:
-                encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'xlm' in pretrained_model_name_or_path:
-                encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            else:
-                raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                                 "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-                                 "'xlm', 'roberta'".format(pretrained_model_name_or_path))
+            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
         # Load and initialize the decoder
         if decoder_model:
@@ -182,22 +128,7 @@ class PreTrainedSeq2seq(nn.Module):
         else:
             kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
             kwargs['is_decoder'] = True  # Make sure the decoder will be an decoder
-            if 'distilbert' in decoder_pretrained_model_name_or_path:
-                decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'roberta' in decoder_pretrained_model_name_or_path:
-                decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'bert' in decoder_pretrained_model_name_or_path:
-                decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
-                decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'gpt2' in decoder_pretrained_model_name_or_path:
-                decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
-                decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'xlnet' in decoder_pretrained_model_name_or_path:
-                decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'xlm' in decoder_pretrained_model_name_or_path:
-                decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
             else:
                 raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                                  "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
...
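For context, the name-based dispatch that the deleted if/elif chains performed by hand is what the AutoModel and AutoModelWithLMHead classes already do: they select the concrete architecture from the pretrained checkpoint name. Below is a minimal sketch of the equivalent call sites after this commit; the top-level package name and the checkpoint name are assumptions for illustration (the diff itself only shows relative imports), and the `is_decoder` flag mirrors the kwarg that the code above forwards through **kwargs.

    # Sketch only: package and checkpoint names are assumptions, not taken from this diff.
    from pytorch_transformers import AutoModel, AutoModelWithLMHead

    # Encoder: a bare base model, dispatched on the checkpoint name (here a BERT checkpoint),
    # loaded with is_decoder=False as in the encoder branch above.
    encoder = AutoModel.from_pretrained('bert-base-uncased', is_decoder=False)

    # Decoder: the same name-based dispatch, but with a language modeling head on top,
    # loaded with is_decoder=True as in the decoder branch above.
    decoder = AutoModelWithLMHead.from_pretrained('bert-base-uncased', is_decoder=True)

The practical effect of the simplification appears to be that supporting a new architecture in the seq2seq wrapper no longer requires editing this file; only the Auto classes' name-to-model mapping has to know about it.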