Commit 6d6c3267 authored by Rémi Louf

take path to pretrained for encoder and decoder for init

parent 0d81fc85
@@ -21,21 +21,20 @@ import logging
 
 import torch
 from torch import nn
 
+from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import PreTrainedModel, SequenceSummary
-from .file_utils import add_start_docstrings
 
 logger = logging.getLogger(__name__)
 class PreTrainedSeq2seq(nn.Module):
     r"""
-        :class:`~transformers.Seq2seq` is a generic model class
-        that will be instantiated as a Seq2seq model with one of the base model classes of the library
-        as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
-        class method.
+        :class:`~transformers.Seq2seq` is a generic model class that will be
+        instantiated as a Seq2seq model with one of the base model classes of
+        the library as encoder and (optionally) as decoder when created with
+        the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
+        method.
     """
 
     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
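For orientation, the constructor simply stores the two halves, so a model can also be assembled by hand. A minimal sketch, assuming 'bert-base-uncased' checkpoints and that `PreTrainedSeq2seq` is importable; the names are illustrative, not part of this commit:

```python
from transformers import AutoModel, AutoModelWithLMHead

# Assemble the two halves by hand; `from_pretrained` below automates this
# and also sets each half's `is_decoder` flag, which this direct route skips.
encoder = AutoModel.from_pretrained('bert-base-uncased')
decoder = AutoModelWithLMHead.from_pretrained('bert-base-uncased')
model = PreTrainedSeq2seq(encoder, decoder)
```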
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
         self.decoder = decoder
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
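With the signature change above, callers now pass both checkpoints explicitly. A hedged usage sketch (checkpoint names and kwargs are illustrative; the `decoder_` prefix routing is implemented in the next hunk):

```python
model = PreTrainedSeq2seq.from_pretrained(
    'bert-base-uncased',             # encoder_pretrained_model_name_or_path
    'bert-base-uncased',             # decoder_pretrained_model_name_or_path
    output_attentions=True,          # un-prefixed kwargs go to the encoder
    decoder_output_attentions=True,  # `decoder_`-prefixed kwargs go to the decoder
)
```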
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
 
             # Loading from a TF checkpoint file instead of a PyTorch model (slower)
             config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
             model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
         """
-        # Extract encoder and decoder model if provided
-        encoder_model = kwargs.pop('encoder_model', None)
-        decoder_model = kwargs.pop('decoder_model', None)
-
-        # Extract decoder kwargs so we only have encoder kwargs for now
-        if decoder_model is None:
-            decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
-
-        decoder_kwargs = {}
-        for key in kwargs.keys():
+        # Separate the encoder- and decoder-specific kwargs. A kwarg is
+        # decoder-specific if its key starts with `decoder_`.
+        kwargs_decoder = {}
+        kwargs_encoder = kwargs
+        for key in list(kwargs_encoder.keys()):  # copy the keys: we pop while iterating
             if key.startswith('decoder_'):
-                decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
+                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
-        # Load and initialize the decoder
-        if encoder_model:
-            encoder = encoder_model
-        else:
-            # Load and initialize the encoder
-            kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        # Load and initialize the encoder and decoder.
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder`, which we need to set correctly.
+        encoder = kwargs_encoder.pop('encoder_model', None)
+        if encoder is None:
+            kwargs_encoder['is_decoder'] = False
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
         # Load and initialize the decoder
-        if decoder_model:
-            decoder = decoder_model
-        else:
-            kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
-            kwargs['is_decoder'] = True  # Make sure the decoder will be a decoder
-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+        decoder = kwargs_decoder.pop('model', None)  # a passed `decoder_model` lands here after prefix stripping
+        if decoder is None:
+            kwargs_decoder['is_decoder'] = True
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
         model = cls(encoder, decoder)
 
         return model
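The kwargs bookkeeping above is the subtle part: the keys must be copied before popping while iterating, and the decoder must receive `kwargs_decoder`, not the encoder's kwargs. A standalone sketch of the intended split, with a hypothetical helper name:

```python
def split_seq2seq_kwargs(kwargs):
    # Hypothetical helper mirroring the logic above: move `decoder_`-prefixed
    # entries into their own dict, stripping the prefix; everything else stays
    # with the encoder. `list(...)` copies the keys so popping inside the loop is safe.
    kwargs_decoder = {}
    for key in list(kwargs.keys()):
        if key.startswith('decoder_'):
            kwargs_decoder[key.replace('decoder_', '')] = kwargs.pop(key)
    return kwargs, kwargs_decoder

# e.g. {'output_attentions': True, 'decoder_config': cfg}
#   -> ({'output_attentions': True}, {'config': cfg})
```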
 
-    def forward(self, *inputs, *kwargs):
+    def forward(self, *inputs, **kwargs):
         # Extract decoder inputs
         decoder_kwargs = {}
         for key in kwargs.keys():
......
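The `forward` body is truncated in this view; the visible lines start the same `decoder_` prefix split for call-time inputs. A sketch of that pattern, under the assumption that the encoder's first output feeds the decoder (the exact wiring is elided here):

```python
def forward(self, *inputs, **kwargs):
    # Split call-time kwargs with the same `decoder_` convention as above.
    kwargs_decoder = {}
    for key in list(kwargs.keys()):
        if key.startswith('decoder_'):
            kwargs_decoder[key.replace('decoder_', '')] = kwargs.pop(key)

    # Assumption: the first element of the encoder output (the hidden states)
    # is what the decoder consumes; the diff does not show this step.
    encoder_outputs = self.encoder(*inputs, **kwargs)
    return self.decoder(encoder_outputs[0], **kwargs_decoder)
```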