Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
56e2ee4e
Commit
56e2ee4e
authored
Oct 17, 2019
by
thomwolf
Browse files
fix model2model
parent
8cd56e30
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
5 deletions
+12
-5
transformers/modeling_seq2seq.py
transformers/modeling_seq2seq.py
+12
-5
No files found.
transformers/modeling_seq2seq.py
View file @
56e2ee4e
...
...
@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary
logger
=
logging
.
getLogger
(
__name__
)
class
PreTrainedSeq2seq
(
PreTrainedModel
):
class
PreTrainedSeq2seq
(
nn
.
Module
):
r
"""
:class:`~transformers.Seq2seq` is a generic model class that will be
instantiated as a Seq2seq model with one of the base model classes of
...
...
@@ -43,7 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel):
self
.
decoder
=
decoder
@
classmethod
def
from_pretrained
(
cls
,
encoder_pretrained_model_name_or_path
,
decoder_pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
encoder_pretrained_model_name_or_path
=
None
,
decoder_pretrained_model_name_or_path
=
None
,
*
model_args
,
**
kwargs
):
r
""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
...
...
@@ -177,8 +177,8 @@ class PreTrainedSeq2seq(PreTrainedModel):
class
Model2Model
(
PreTrainedSeq2seq
):
def
__init__
(
self
):
super
(
Model2Model
,
self
).
__init__
()
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Model2Model
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
tie_weights
()
def
tie_weights
(
self
):
...
...
@@ -197,7 +197,14 @@ class Model2Model(PreTrainedSeq2seq):
by a model-specific keyword (bert, )...
"""
# self._tie_or_clone_weights(self.encoder, self.decoder)
raise
NotImplementedError
pass
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
args
,
**
kwargs
):
model
=
super
(
Model2Model
,
cls
).
from_pretrained
(
encoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
decoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
return
model
class
Model2LSTM
(
PreTrainedSeq2seq
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment