chenpangpang / transformers / Commits

Commit 6d6c3267, authored Oct 15, 2019 by Rémi Louf

take path to pretrained for encoder and decoder for init

parent 0d81fc85
Showing 1 changed file with 26 additions and 33 deletions.
transformers/modeling_seq2seq.py (view file @ 6d6c3267)
@@ -21,21 +21,20 @@ import logging

 import torch
 from torch import nn

-from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
-from .modeling_utils import PreTrainedModel, SequenceSummary
+from .file_utils import add_start_docstrings


 logger = logging.getLogger(__name__)


 class PreTrainedSeq2seq(nn.Module):
     r"""
-        :class:`~transformers.Seq2seq` is a generic model class
-        that will be instantiated as a Seq2seq model with one of the base model classes of the library
-        as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
-        class method.
+        :class:`~transformers.Seq2seq` is a generic model class that will be
+        instantiated as a Seq2seq model with one of the base model classes of
+        the library as encoder and (optionally) as decoder when created with
+        the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
+        method.
     """

     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
         self.decoder = decoder

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
         r""" Instantiates one of the base model classes of the library
         from a pre-trained model configuration.

         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
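With this change, callers pass separate checkpoint paths for the encoder and the decoder instead of a single pretrained_model_name_or_path. A minimal usage sketch of the new signature follows; the checkpoint names and the decoder_ kwarg are illustrative, not taken from this commit:

# Hypothetical call against the new two-path signature introduced above.
from transformers.modeling_seq2seq import PreTrainedSeq2seq

model = PreTrainedSeq2seq.from_pretrained(
    'bert-base-uncased',             # encoder_pretrained_model_name_or_path
    'bert-base-uncased',             # decoder_pretrained_model_name_or_path
    decoder_output_attentions=True,  # `decoder_`-prefixed kwargs are stripped and routed to the decoder
)

Keyword arguments without the decoder_ prefix go to the encoder, mirroring the kwargs split implemented in the hunk below.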
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
             # Loading from a TF checkpoint file instead of a PyTorch model (slower)
             config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
             model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
         """
-        # Extract encoder and decoder model if provided
-        encoder_model = kwargs.pop('encoder_model', None)
-        decoder_model = kwargs.pop('decoder_model', None)
-
-        # Extract decoder kwargs so we only have encoder kwargs for now
-        if decoder_model is None:
-            decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
-
-        decoder_kwargs = {}
-        for key in kwargs.keys():
+        # Separate the encoder- and decoder- specific kwargs. A kwarg is
+        # decoder-specific if the key starts with `decoder_`
+        kwargs_decoder = {}
+        kwargs_encoder = kwargs
+        for key in kwargs_encoder.keys():
             if key.startswith('decoder_'):
-                decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
+                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
-        # Load and initialize the decoder
-        if encoder_model:
-            encoder = encoder_model
-        else:
-            # Load and initialize the encoder
-            kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs.pop('encoder_model', None)
+        if encoder is None:
+            kwargs_encoder['is_decoder'] = False
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-        # Load and initialize the decoder
-        if decoder_model:
-            decoder = decoder_model
-        else:
-            kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
-            kwargs['is_decoder'] = True  # Make sure the decoder will be a decoder
-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+        decoder = kwargs.pop('decoder_model', None)
+        if decoder is None:
+            kwargs_decoder['is_decoder'] = True
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

         model = cls(encoder, decoder)

         return model
-    def forward(self, *inputs, *kwargs):
+    def forward(self, *inputs, **kwargs):
         # Extract decoder inputs
         decoder_kwargs = {}
         for key in kwargs.keys():
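Both from_pretrained and forward rely on the same idiom of splitting a flat kwargs dict on the decoder_ prefix. Below is a self-contained sketch of that idiom, using a hypothetical helper name that is not part of the commit; it iterates over a copy of the keys so that popping during the loop is safe:

def split_decoder_kwargs(kwargs):
    """Split kwargs into encoder kwargs and decoder kwargs (hypothetical helper).

    Keys starting with 'decoder_' go to the decoder with the prefix stripped;
    everything else stays with the encoder.
    """
    kwargs_decoder = {}
    kwargs_encoder = kwargs
    for key in list(kwargs_encoder.keys()):  # copy the keys so .pop() is safe
        if key.startswith('decoder_'):
            kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
    return kwargs_encoder, kwargs_decoder

kwargs_encoder, kwargs_decoder = split_decoder_kwargs(
    {'output_attentions': True, 'decoder_output_attentions': False}
)
# kwargs_encoder == {'output_attentions': True}
# kwargs_decoder == {'output_attentions': False}

Iterating over list(kwargs.keys()) rather than the live key view avoids mutating the dict while iterating over it, which Python 3 rejects with a RuntimeError.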