Project: chenpangpang/transformers
Commit b7141a1b, authored Oct 14, 2019 by thomwolf
Commit message: "maxi simplication"
Parent: bfbe68f0

Showing 1 changed file with 3 additions and 72 deletions:
transformers/modeling_seq2seq.py (+3, -72)
@@ -21,14 +21,7 @@ import logging

 import torch
 from torch import nn

-from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
-from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
-from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
-from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
-from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
-from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
-from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
-from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
+from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import PreTrainedModel, SequenceSummary
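
The point of the change is that `AutoModel` / `AutoModelWithLMHead` already perform the name-based dispatch this file used to duplicate, so the per-architecture imports can be dropped. A minimal sketch of that behaviour, using the package-level import rather than the file's relative imports (checkpoint names are only examples; in this 2019-era release the dispatch was driven by name patterns, in later releases by the config):

```python
from transformers import AutoModel, AutoModelWithLMHead

# AutoModel inspects the checkpoint identifier and instantiates the matching
# base architecture, e.g. a BertModel for "bert-base-uncased".
encoder = AutoModel.from_pretrained("bert-base-uncased")

# AutoModelWithLMHead does the same but returns the variant with a
# language-modeling head on top, e.g. a GPT2LMHeadModel for "gpt2".
decoder = AutoModelWithLMHead.from_pretrained("gpt2")

print(type(encoder).__name__)   # BertModel
print(type(decoder).__name__)   # GPT2LMHeadModel
```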
@@ -43,22 +36,6 @@ class PreTrainedSeq2seq(nn.Module):
         that will be instantiated as a Seq2seq model with one of the base model classes of the library
         as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
         class method.
-
-        The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
-
-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `distilbert`: DistilBertModel (DistilBERT model)
-            - contains `roberta`: RobertaModel (RoBERTa model)
-            - contains `bert`: BertModel (Bert model)
-            - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
-            - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
-            - contains `xlnet`: XLNetModel (XLNet model)
-            - contains `xlm`: XLMModel (XLM model)
-
-        This class cannot be instantiated using `__init__()` (throws an error).
     """

     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
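
For context, the class being edited is a thin `nn.Module` wrapper that holds an encoder and a decoder as sub-modules, as the surviving `__init__` shows. A toy sketch of that shape (the names and the simplified `forward` below are illustrative, not the real class):

```python
import torch
from torch import nn


class Seq2seqSketch(nn.Module):
    """Toy stand-in for the wrapper: it only stores the two sub-modules."""

    def __init__(self, encoder, decoder):
        super(Seq2seqSketch, self).__init__()
        self.encoder = encoder   # a bare model returning hidden states
        self.decoder = decoder   # a model with a language-modeling head

    def forward(self, encoder_input_ids, decoder_input_ids):
        # The real class wires the encoder's hidden states into the decoder;
        # here we just run both and return the raw outputs.
        encoder_outputs = self.encoder(encoder_input_ids)
        decoder_outputs = self.decoder(decoder_input_ids)
        return decoder_outputs, encoder_outputs
```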
...
@@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module):
...
@@ -69,18 +46,6 @@ class PreTrainedSeq2seq(nn.Module):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
r
""" Instantiates one of the base model classes of the library
r
""" Instantiates one of the base model classes of the library
from a pre-trained model configuration.
from a pre-trained model configuration.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
To train the model, you should first set it back in training mode with `model.train()`
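
The retained docstring lines describe the train/eval state of the returned model. A short usage sketch of that convention (the checkpoint name is only an example):

```python
from transformers import AutoModelWithLMHead

model = AutoModelWithLMHead.from_pretrained("gpt2")

# from_pretrained() returns the model in evaluation mode: dropout is disabled,
# which is what you want for inference or weight inspection.
assert not model.training

# Before fine-tuning, switch back to training mode so dropout is active again.
model.train()
assert model.training
```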
@@ -155,26 +120,7 @@ class PreTrainedSeq2seq(nn.Module):
         else:
             # Load and initialize the encoder
             kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            if 'distilbert' in pretrained_model_name_or_path:
-                encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'roberta' in pretrained_model_name_or_path:
-                encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'bert' in pretrained_model_name_or_path:
-                encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'openai-gpt' in pretrained_model_name_or_path:
-                encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'gpt2' in pretrained_model_name_or_path:
-                encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'transfo-xl' in pretrained_model_name_or_path:
-                encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'xlnet' in pretrained_model_name_or_path:
-                encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            elif 'xlm' in pretrained_model_name_or_path:
-                encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            else:
-                raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                                 "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-                                 "'xlm', 'roberta'".format(pretrained_model_name_or_path))
+            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

         # Load and initialize the decoder
         if decoder_model:
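
This hunk replaces a hand-written name-matching chain with one call to `AutoModel.from_pretrained`, which keeps the ordering-sensitive dispatch ('distilbert' and 'roberta' must be tested before 'bert', and so on) in a single place. The same refactor can be sketched generically as an ordered registry instead of an if/elif ladder; the names below are illustrative, not the library's internals:

```python
from collections import OrderedDict

# Order matters: more specific substrings must be checked before more general
# ones, e.g. 'distilbert' and 'roberta' before 'bert'.
_MODEL_REGISTRY = OrderedDict([
    ("distilbert", "DistilBertModel"),
    ("roberta", "RobertaModel"),
    ("bert", "BertModel"),
    ("openai-gpt", "OpenAIGPTModel"),
    ("gpt2", "GPT2Model"),
    ("transfo-xl", "TransfoXLModel"),
    ("xlnet", "XLNetModel"),
    ("xlm", "XLMModel"),
])


def resolve_model_class(name):
    """Return the class name whose pattern first matches the checkpoint identifier."""
    for pattern, model_class in _MODEL_REGISTRY.items():
        if pattern in name:
            return model_class
    raise ValueError(
        "Unrecognized model identifier in {}. Should contain one of {}".format(
            name, ", ".join(_MODEL_REGISTRY)))


print(resolve_model_class("bert-base-uncased"))        # BertModel
print(resolve_model_class("distilbert-base-uncased"))  # DistilBertModel
```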
@@ -182,22 +128,7 @@ class PreTrainedSeq2seq(nn.Module):
         else:
             kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
             kwargs['is_decoder'] = True  # Make sure the decoder will be an decoder
-            if 'distilbert' in decoder_pretrained_model_name_or_path:
-                decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'roberta' in decoder_pretrained_model_name_or_path:
-                decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'bert' in decoder_pretrained_model_name_or_path:
-                decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
-                decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'gpt2' in decoder_pretrained_model_name_or_path:
-                decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
-                decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'xlnet' in decoder_pretrained_model_name_or_path:
-                decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
-            elif 'xlm' in decoder_pretrained_model_name_or_path:
-                decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
             else:
                 raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                                  "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
...
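
Note the asymmetry in the new code: the encoder is loaded with `AutoModel` (a bare transformer that returns hidden states) while the decoder is loaded with `AutoModelWithLMHead` (the variant with a language-modeling head that produces token logits), with `is_decoder=True` forwarded through the kwargs. A hedged sketch of that difference, using the same checkpoint for both so the shapes are comparable (the checkpoint name and input ids are only examples, and the tuple-style outputs assume the 2019-era return convention):

```python
import torch
from transformers import AutoModel, AutoModelWithLMHead

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # toy BERT-style ids

encoder = AutoModel.from_pretrained("bert-base-uncased")
hidden_states = encoder(input_ids)[0]   # (batch, seq_len, hidden_size)

decoder = AutoModelWithLMHead.from_pretrained("bert-base-uncased")
logits = decoder(input_ids)[0]          # (batch, seq_len, vocab_size)

print(hidden_states.shape, logits.shape)
```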