Commit 6d6c3267 authored by Rémi Louf

take path to pretrained for encoder and decoder for init

parent 0d81fc85
@@ -21,21 +21,20 @@ import logging
 import torch
 from torch import nn
 
-from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import PreTrainedModel, SequenceSummary
+from .file_utils import add_start_docstrings
 
 logger = logging.getLogger(__name__)
 
 
 class PreTrainedSeq2seq(nn.Module):
     r"""
-        :class:`~transformers.Seq2seq` is a generic model class
-        that will be instantiated as a Seq2seq model with one of the base model classes of the library
-        as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
-        class method.
+        :class:`~transformers.Seq2seq` is a generic model class that will be
+        instantiated as a Seq2seq model with one of the base model classes of
+        the library as encoder and (optionally) as decoder when created with
+        the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
+        method.
     """
 
     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
@@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module):
         self.decoder = decoder
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
         r""" Instantiates one of the base model classes of the library
         from a pre-trained model configuration.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module):
             # Loading from a TF checkpoint file instead of a PyTorch model (slower)
             config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
             model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
 
         """
-        # Extract encoder and decoder model if provided
-        encoder_model = kwargs.pop('encoder_model', None)
-        decoder_model = kwargs.pop('decoder_model', None)
-
-        # Extract decoder kwargs so we only have encoder kwargs for now
-        if decoder_model is None:
-            decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
-        decoder_kwargs = {}
-        for key in kwargs.keys():
+        # Separate the encoder- and decoder-specific kwargs. A kwarg is
+        # decoder-specific if the key starts with `decoder_`
+        kwargs_decoder = {}
+        kwargs_encoder = kwargs
+        for key in list(kwargs_encoder.keys()):
             if key.startswith('decoder_'):
-                decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
+                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
 
-        # Load and initialize the encoder
-        if encoder_model:
-            encoder = encoder_model
-        else:
-            # Load and initialize the encoder
-            kwargs['is_decoder'] = False  # Make sure the encoder will be an encoder
-            encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs.pop('encoder_model', None)
+        if encoder is None:
+            kwargs_encoder['is_decoder'] = False
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
-        # Load and initialize the decoder
-        if decoder_model:
-            decoder = decoder_model
-        else:
-            kwargs.update(decoder_kwargs)  # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
-            kwargs['is_decoder'] = True  # Make sure the decoder will be a decoder
-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
+        decoder = kwargs.pop('decoder_model', None)
+        if decoder is None:
+            kwargs_decoder['is_decoder'] = True
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
         model = cls(encoder, decoder)
 
         return model
 
-    def forward(self, *inputs, *kwargs):
+    def forward(self, *inputs, **kwargs):
         # Extract decoder inputs
         decoder_kwargs = {}
         for key in kwargs.keys():
...
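For context, here is a minimal usage sketch of the classmethod as changed by this commit. The module path, checkpoint names and the `decoder_output_attentions` keyword are illustrative assumptions; the only behaviour taken from the diff is that the first two positional arguments name the encoder and decoder checkpoints, and that any kwarg prefixed with `decoder_` is stripped of the prefix and routed to the decoder.

# Usage sketch (not part of the commit). The module path, checkpoint names
# and the `decoder_output_attentions` kwarg are illustrative assumptions.
from transformers.modeling_seq2seq import PreTrainedSeq2seq

model = PreTrainedSeq2seq.from_pretrained(
    'bert-base-uncased',   # encoder_pretrained_model_name_or_path -> loaded with AutoModel
    'bert-base-uncased',   # decoder_pretrained_model_name_or_path -> loaded with AutoModelWithLMHead
    # kwargs prefixed with `decoder_` are stripped of the prefix and passed
    # only to the decoder; everything else goes to the encoder.
    decoder_output_attentions=True,
)

# The wrapper holds both sub-models; `is_decoder` was set to False on the
# encoder and True on the decoder while loading.
encoder, decoder = model.encoder, model.decoder

The prefix-based routing keeps a single **kwargs signature while still letting the two sub-models be configured independently, which is also how forward() separates decoder inputs in the truncated part of the diff.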