Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b7141a1b
Commit
b7141a1b
authored
Oct 14, 2019
by
thomwolf
Browse files
maxi simplification
parent
bfbe68f0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
72 deletions
+3
-72
transformers/modeling_seq2seq.py
transformers/modeling_seq2seq.py
+3
-72
No files found.
transformers/modeling_seq2seq.py
View file @
b7141a1b
...
...
@@ -21,14 +21,7 @@ import logging
import
torch
from
torch
import
nn
from
.modeling_bert
import
BertModel
,
BertForMaskedLM
,
BertForSequenceClassification
,
BertForQuestionAnswering
from
.modeling_openai
import
OpenAIGPTModel
,
OpenAIGPTLMHeadModel
from
.modeling_gpt2
import
GPT2Model
,
GPT2LMHeadModel
from
.modeling_transfo_xl
import
TransfoXLModel
,
TransfoXLLMHeadModel
from
.modeling_xlnet
import
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
from
.modeling_xlm
import
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_auto
import
AutoModel
,
AutoModelWithLMHead
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
...
...
@@ -43,22 +36,6 @@ class PreTrainedSeq2seq(nn.Module):
that will be instantiated as a Seq2seq model with one of the base model classes of the library
as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected by the first pattern that matches
the `pretrained_model_name_or_path` string (checked in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
def
__init__
(
self
,
encoder
,
decoder
):
super
(
PreTrainedSeq2seq
,
self
).
__init__
()
...
...
@@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
r
""" Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The model class to instantiate is selected by the first pattern that matches
the `pretrained_model_name_or_path` string (checked in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
To train the model, you should first set it back in training mode with `model.train()`.
...
...
@@ -155,26 +120,7 @@ class PreTrainedSeq2seq(nn.Module):
else
:
# Load and initialize the encoder
kwargs
[
'is_decoder'
]
=
False
# Make sure the encoder will be an encoder
if
'distilbert'
in
pretrained_model_name_or_path
:
encoder
=
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
encoder
=
RobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
encoder
=
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'openai-gpt'
in
pretrained_model_name_or_path
:
encoder
=
OpenAIGPTModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'gpt2'
in
pretrained_model_name_or_path
:
encoder
=
GPT2Model
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'transfo-xl'
in
pretrained_model_name_or_path
:
encoder
=
TransfoXLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
encoder
=
XLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlm'
in
pretrained_model_name_or_path
:
encoder
=
XLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
else
:
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
encoder
=
AutoModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
# Load and initialize the decoder
if
decoder_model
:
...
...
@@ -182,22 +128,7 @@ class PreTrainedSeq2seq(nn.Module):
else
:
kwargs
.
update
(
decoder_kwargs
)
# Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
kwargs
[
'is_decoder'
]
=
True
# Make sure the decoder will be a decoder
if
'distilbert'
in
decoder_pretrained_model_name_or_path
:
decoder
=
DistilBertModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'roberta'
in
decoder_pretrained_model_name_or_path
:
decoder
=
RobertaModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'bert'
in
decoder_pretrained_model_name_or_path
:
decoder
=
BertModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'openai-gpt'
in
decoder_pretrained_model_name_or_path
:
decoder
=
OpenAIGPTModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'gpt2'
in
decoder_pretrained_model_name_or_path
:
decoder
=
GPT2Model
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'transfo-xl'
in
decoder_pretrained_model_name_or_path
:
decoder
=
TransfoXLModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'xlnet'
in
decoder_pretrained_model_name_or_path
:
decoder
=
XLNetModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
elif
'xlm'
in
decoder_pretrained_model_name_or_path
:
decoder
=
XLMModel
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
decoder
=
AutoModelWithLMHead
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs
)
else
:
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment