chenpangpang/transformers · Commit 6d6c3267
authored Oct 15, 2019 by Rémi Louf
take path to pretrained for encoder and decoder for init
parent 0d81fc85
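In practice, the new signature means callers name the encoder and decoder checkpoints as two positional arguments instead of passing one path and, optionally, a `decoder_pretrained_model_name_or_path` keyword that the old code popped from `kwargs`. A minimal sketch of a call under the new signature (the checkpoint names are illustrative, and the direct module import is an assumption about how this in-development class is reached):

    from transformers.modeling_seq2seq import PreTrainedSeq2seq

    # Encoder and decoder checkpoints are now two positional arguments;
    # decoder-specific options keep the `decoder_` prefix.
    model = PreTrainedSeq2seq.from_pretrained(
        'bert-base-uncased',   # encoder_pretrained_model_name_or_path
        'bert-base-uncased',   # decoder_pretrained_model_name_or_path
    )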
Showing 1 changed file with 26 additions and 33 deletions.
transformers/modeling_seq2seq.py (+26, -33)
@@ -21,21 +21,20 @@ import logging
 import torch
 from torch import nn
 
-from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import PreTrainedModel, SequenceSummary
+from .file_utils import add_start_docstrings
 
 logger = logging.getLogger(__name__)
 
 
 class PreTrainedSeq2seq(nn.Module):
     r"""
-        :class:`~transformers.Seq2seq` is a generic model class
-        that will be instantiated as a Seq2seq model with one of the base model classes of the library
-        as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
-        class method.
+        :class:`~transformers.Seq2seq` is a generic model class that will be
+        instantiated as a Seq2seq model with one of the base model classes of
+        the library as encoder and (optionally) as decoder when created with
+        the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
+        method.
     """
 
     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
         self.decoder = decoder
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
         r""" Instantiates one of the base model classes of the library
         from a pre-trained model configuration.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
             # Loading from a TF checkpoint file instead of a PyTorch model (slower)
             config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
             model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
         """
-        # Extract encoder and decoder model if provided
-        encoder_model = kwargs.pop('encoder_model', None)
-        decoder_model = kwargs.pop('decoder_model', None)
-
-        # Extract decoder kwargs so we only have encoder kwargs for now
-        if decoder_model is None:
-            decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
-        decoder_kwargs = {}
-        for key in kwargs.keys():
+        # Separate the encoder- and decoder- specific kwargs. A kwarg is
+        # decoder-specific if the key starts with `decoder_`
+        kwargs_decoder = {}
+        kwargs_encoder = kwargs
+        for key in kwargs_encoder.keys():
             if key.startswith('decoder_'):
-                decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
+                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
 
-        # Load and initialize the decoder
-        if encoder_model:
-            encoder = encoder_model
-        else:
-            # Load and initialize the encoder
-            kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs.pop('encoder_model', None)
+        if encoder is None:
+            kwargs_encoder['is_decoder'] = False
+            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
-        # Load and initialize the decoder
-        if decoder_model:
-            decoder = decoder_model
-        else:
-            kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
-            kwargs['is_decoder'] = True  # Make sure the decoder will be a decoder
-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+        decoder = kwargs.pop('decoder_model', None)
+        if decoder is None:
+            kwargs_decoder['is_decoder'] = True
+            decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
 
         model = cls(encoder, decoder)
 
         return model
 
-    def forward(self, *inputs, *kwargs):
+    def forward(self, *inputs, **kwargs):
         # Extract decoder inputs
         decoder_kwargs = {}
         for key in kwargs.keys():
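The heart of the new `from_pretrained` is its kwarg-routing scheme: any keyword argument prefixed with `decoder_` is stripped of the prefix and routed to the decoder, and everything else goes to the encoder. A standalone sketch of that scheme (a hypothetical helper, not part of the library; note the explicit key-list copy, which makes popping during iteration safe on Python 3):

    def split_encoder_decoder_kwargs(**kwargs):
        # Route `decoder_`-prefixed kwargs to the decoder, the rest to the encoder.
        kwargs_decoder = {}
        kwargs_encoder = kwargs
        # Copy the keys so entries can be popped while iterating.
        for key in list(kwargs_encoder.keys()):
            if key.startswith('decoder_'):
                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
        return kwargs_encoder, kwargs_decoder

    kwargs_encoder, kwargs_decoder = split_encoder_decoder_kwargs(
        output_attentions=True, decoder_output_attentions=False)
    assert kwargs_encoder == {'output_attentions': True}
    assert kwargs_decoder == {'output_attentions': False}

Each dictionary then receives the `is_decoder` flag before being forwarded to `AutoModel.from_pretrained` or `AutoModelWithLMHead.from_pretrained`, which is how the same pretrained weights can be loaded once as an encoder and once as a decoder.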