chenpangpang / transformers · Commits

Commit 9b71fc9a
Authored Oct 16, 2019 by Rémi Louf
Parent: 95ec1d08

    tying weights is going to be a clusterfuck

Showing 1 changed file with 56 additions and 25 deletions.

transformers/modeling_seq2seq.py  +56 -25  (view file @ 9b71fc9a)
@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
 logger = logging.getLogger(__name__)
 
 
-class PreTrainedSeq2seq(nn.Module):
+class PreTrainedSeq2seq(PreTrainedModel):
     r"""
         :class:`~transformers.Seq2seq` is a generic model class that will be
         instantiated as a Seq2seq model with one of the base model classes of
@@ -36,13 +36,20 @@ class PreTrainedSeq2seq(nn.Module):
         the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
         method.
     """
 
     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
 
     @classmethod
     def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
         r""" Instantiates an encoder and a decoder from one or two base classes
         of the library from pre-trained model checkpoints.
@@ -110,21 +117,25 @@ class PreTrainedSeq2seq(nn.Module):
         kwargs_decoder = {}
         kwargs_encoder = kwargs
         for key in kwargs_encoder.keys():
-            if key.startswith('decoder_'):
-                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
+            if key.startswith("decoder_"):
+                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
 
         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
         # by the value of the flag `is_decoder` that we need to set correctly.
-        encoder = kwargs.pop('encoder_model', None)
+        encoder = kwargs.pop("encoder_model", None)
         if encoder is None:
-            kwargs_encoder['is_decoder'] = False
+            kwargs_encoder["is_decoder"] = False
             encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
-        decoder = kwargs.pop('decoder_model', None)
+        decoder = kwargs.pop("decoder_model", None)
         if decoder is None:
-            kwargs_decoder['is_decoder'] = True
+            kwargs_decoder["is_decoder"] = True
             decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
         model = cls(encoder, decoder)
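As a reading aid, here is a hypothetical call sketch for the `from_pretrained` classmethod shown in the hunk above. It is not part of this commit, the checkpoint names are placeholders, and since this revision is explicitly work in progress it sketches the intended call pattern rather than guaranteed behavior of this exact revision.

# Hypothetical usage sketch, not part of this commit; checkpoint names are placeholders.
from transformers.modeling_seq2seq import PreTrainedSeq2seq

# The encoder is loaded with AutoModel (is_decoder=False) and the decoder with
# AutoModelWithLMHead (is_decoder=True), exactly as in the hunk above.
model = PreTrainedSeq2seq.from_pretrained("bert-base-uncased", "bert-base-uncased")

# Keyword arguments prefixed with "decoder_" are meant to lose the prefix and be routed
# to the decoder, everything else to the encoder, e.g.
#     PreTrainedSeq2seq.from_pretrained("bert-base-uncased", "bert-base-uncased",
#                                       decoder_output_attentions=True)
# Note that, as committed, the routing loop pops from kwargs_encoder while iterating
# over its keys() view, which Python 3 rejects at runtime, so the prefixed form is
# shown here only as the intended usage.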
@@ -153,11 +164,11 @@ class PreTrainedSeq2seq(nn.Module):
         kwargs_decoder = {}
         kwargs_encoder = kwargs
         for key in kwargs_encoder.keys():
-            if key.startswith('decoder_'):
-                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
+            if key.startswith("decoder_"):
+                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
 
         # Encode if needed (training, first prediction pass)
-        encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
+        encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
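A hypothetical forward-pass sketch follows; it is not part of this commit, and the positional `encoder_input_ids` / `decoder_input_ids` arguments and the dummy inputs are assumptions based on the hunks above and below.

# Hypothetical forward sketch, not part of this commit; inputs are dummy placeholders
# and the call pattern reflects the intent of this work-in-progress revision.
import torch
from transformers.modeling_seq2seq import PreTrainedSeq2seq

model = PreTrainedSeq2seq.from_pretrained("bert-base-uncased", "bert-base-uncased")
encoder_input_ids = torch.tensor([[101, 7592, 2088, 102]])  # placeholder token ids
decoder_input_ids = torch.tensor([[101, 7592, 102]])        # placeholder token ids

# First pass: no cached states are supplied, so self.encoder runs and its first
# output becomes encoder_hidden_states.
outputs = model(encoder_input_ids, decoder_input_ids)

# Later passes can hand the encoder states back in; the hunk above pops
# "encoder_hidden_states" from kwargs_encoder and skips re-encoding.
cached_states = model.encoder(encoder_input_ids)[0]
outputs = model(encoder_input_ids, decoder_input_ids, encoder_hidden_states=cached_states)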
@@ -165,29 +176,49 @@ class PreTrainedSeq2seq(nn.Module):
             encoder_outputs = ()
 
         # Decode
-        kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
 
         return decoder_outputs + encoder_outputs
 
 
 class Model2Model(PreTrainedSeq2seq):
-    def tie_weights():
-        # We should tie encoder and decoder embeddings if possible here
-        pass
+    def __init__(self):
+        super(Model2Model, self).__init__()
+        self.tie_weights()
+
+    def tie_weights(self):
+        """ Tying the encoder and decoders' embeddings together.
+
+        We need for each to get down to the embedding weights. However the
+        different model classes are inconsistent to that respect:
+        - BertModel: embeddings.word_embeddings
+        - RoBERTa: embeddings.word_embeddings
+        - XLMModel: embeddings
+        - GPT2: wte
+        - BertForMaskedLM: bert.embeddings.word_embeddings
+        - RobertaForMaskedLM: roberta.embeddings.word_embeddings
+
+        argument of the XEmbedding layer for each model, but it is "blocked"
+        by a model-specific keyword (bert, )...
+        """
+        # self._tie_or_clone_weights(self.encoder, self.decoder)
+        raise NotImplementedError
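The docstring above lists where each model class keeps its input embeddings. Below is a minimal sketch, not part of this commit, of one way `tie_weights` could locate them; the helper name and fallback order are invented here, while the attribute paths come straight from the docstring.

# Hypothetical helper, not part of this commit; attribute paths are taken from the
# docstring above, the function itself is invented for illustration.
def _get_word_embeddings(model):
    candidate_paths = (
        "embeddings.word_embeddings",          # BertModel, RoBERTa
        "bert.embeddings.word_embeddings",     # BertForMaskedLM
        "roberta.embeddings.word_embeddings",  # RobertaForMaskedLM
        "embeddings",                          # XLMModel
        "wte",                                 # GPT2
    )
    for path in candidate_paths:
        module = model
        try:
            for attr in path.split("."):
                module = getattr(module, attr)
        except AttributeError:
            continue
        return module  # first path that resolves wins
    return None

With the two embedding modules located, the commented-out `self._tie_or_clone_weights(...)` call is presumably where the actual tying would happen; later versions of the library expose this lookup through a `get_input_embeddings()` hook on `PreTrainedModel`.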
 class Model2LSTM(PreTrainedSeq2seq):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        if kwargs.get('decoder_model', None) is None:
+        if kwargs.get("decoder_model", None) is None:
             # We will create a randomly initilized LSTM model as decoder
-            if 'decoder_config' not in kwargs:
+            if "decoder_config" not in kwargs:
                 raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
                                  " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
                                  " - a dictionary of configuration parameters that will be used to initialize a"
                                  " torch.nn.LSTM model as `decoder_config` keyword argument. "
-                                 " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
-            kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config'))
+                                 " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
+                )
+            kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
         model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
         return model
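For completeness, a hypothetical usage sketch for `Model2LSTM`, not part of this commit. It follows the first option spelled out in the `ValueError` above; the checkpoint name is a placeholder.

# Hypothetical usage sketch, not part of this commit. Because decoder_model is supplied,
# the randomly initialized LSTM branch above is skipped. Note that, as committed,
# "decoder_model" itself starts with "decoder_" and would be rerouted by the kwargs
# loop in PreTrainedSeq2seq.from_pretrained, so this shows the intended call pattern
# rather than guaranteed behavior of this exact work-in-progress revision.
import torch
from transformers.modeling_seq2seq import Model2LSTM

lstm_decoder = torch.nn.LSTM(input_size=768, hidden_size=768, num_layers=2)
model = Model2LSTM.from_pretrained(
    "bert-base-uncased",  # encoder checkpoint (placeholder)
    None,                 # decoder checkpoint, unused when decoder_model is given
    decoder_model=lstm_decoder,
)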