Commit 9b71fc9a authored by Rémi Louf

tying weights is going to be a clusterfuck

parent 95ec1d08
@@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary

 logger = logging.getLogger(__name__)

-class PreTrainedSeq2seq(nn.Module):
+class PreTrainedSeq2seq(PreTrainedModel):
     r"""
         :class:`~transformers.Seq2seq` is a generic model class that will be
         instantiated as a Seq2seq model with one of the base model classes of
@@ -36,13 +36,20 @@ class PreTrainedSeq2seq(nn.Module):
         the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
         method.
     """

     def __init__(self, encoder, decoder):
         super(PreTrainedSeq2seq, self).__init__()
         self.encoder = encoder
         self.decoder = decoder

     @classmethod
-    def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
+    def from_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path,
+        decoder_pretrained_model_name_or_path,
+        *model_args,
+        **kwargs
+    ):
         r""" Instantiates an encoder and a decoder from one or two base classes
         of the library from pre-trained model checkpoints.
@@ -110,21 +117,25 @@ class PreTrainedSeq2seq(nn.Module):
         kwargs_decoder = {}
         kwargs_encoder = kwargs
         for key in kwargs_encoder.keys():
-            if key.startswith('decoder_'):
-                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
+            if key.startswith("decoder_"):
+                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)

         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
         # by the value of the flag `is_decoder` that we need to set correctly.
-        encoder = kwargs.pop('encoder_model', None)
+        encoder = kwargs.pop("encoder_model", None)
         if encoder is None:
-            kwargs_encoder['is_decoder'] = False
-            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+            kwargs_encoder["is_decoder"] = False
+            encoder = AutoModel.from_pretrained(
+                encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+            )

-        decoder = kwargs.pop('decoder_model', None)
+        decoder = kwargs.pop("decoder_model", None)
         if decoder is None:
-            kwargs_decoder['is_decoder'] = True
-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+            kwargs_decoder["is_decoder"] = True
+            decoder = AutoModelWithLMHead.from_pretrained(
+                decoder_pretrained_model_name_or_path, **kwargs_decoder
+            )

         model = cls(encoder, decoder)
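For orientation, here is a minimal sketch of the call pattern this `from_pretrained` is working toward. Given that this is a work-in-progress commit, the snippet illustrates intent rather than something guaranteed to run against this exact revision; checkpoint names and flags are placeholders, and it assumes `PreTrainedSeq2seq` is in scope. The first positional argument picks the encoder checkpoint, the second the decoder checkpoint, and any keyword prefixed with `decoder_` is stripped of the prefix and routed to the decoder only.

# Hypothetical usage sketch -- checkpoint names and flags are illustrative only.
model = PreTrainedSeq2seq.from_pretrained(
    "bert-base-uncased",             # encoder checkpoint, loaded with AutoModel (is_decoder=False)
    "bert-base-uncased",             # decoder checkpoint, loaded with AutoModelWithLMHead (is_decoder=True)
    output_attentions=True,          # stays in kwargs_encoder
    decoder_output_attentions=True,  # prefix stripped, becomes output_attentions=True in kwargs_decoder
)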
@@ -153,11 +164,11 @@ class PreTrainedSeq2seq(nn.Module):
         kwargs_decoder = {}
         kwargs_encoder = kwargs
         for key in kwargs_encoder.keys():
-            if key.startswith('decoder_'):
-                kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
+            if key.startswith("decoder_"):
+                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)

         # Encode if needed (training, first prediction pass)
-        encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
+        encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
@@ -165,29 +176,49 @@ class PreTrainedSeq2seq(nn.Module):
             encoder_outputs = ()

         # Decode
-        kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)

         return decoder_outputs + encoder_outputs
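The encode-then-decode flow above suggests the following usage, continuing the hypothetical `from_pretrained` sketch. This is a hedged sketch only: the forward signature is assumed to be `(encoder_input_ids, decoder_input_ids, **kwargs)` from the names used in the body, and the toy ids are placeholders for tokenizer output.

import torch

# Assumed toy inputs; real ids would come from a tokenizer.
encoder_input_ids = torch.tensor([[101, 7592, 2088, 102]])
decoder_input_ids = torch.tensor([[101, 7592, 102]])

# First pass: the encoder runs and its last hidden state feeds the decoder.
outputs = model(encoder_input_ids, decoder_input_ids)

# Later passes (e.g. step-by-step generation) can hand the cached states back in,
# in which case the encoder branch above is skipped entirely.
cached_states = model.encoder(encoder_input_ids)[0]  # last hidden state, shape (1, 4, hidden_size)
outputs = model(
    encoder_input_ids,
    decoder_input_ids,
    encoder_hidden_states=cached_states,  # popped from kwargs_encoder, encoder not re-run
)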
 class Model2Model(PreTrainedSeq2seq):
-    def tie_weights():
-        # We should tie encoder and decoder embeddings if possible here
-        pass
+    def __init__(self):
+        super(Model2Model, self).__init__()
+        self.tie_weights()
+
+    def tie_weights(self):
+        """ Tying the encoder and decoders' embeddings together.
+
+        We need for each to get down to the embedding weights. However the
+        different model classes are inconsistent in that respect:
+        - BertModel: embeddings.word_embeddings
+        - RoBERTa: embeddings.word_embeddings
+        - XLMModel: embeddings
+        - GPT2: wte
+        - BertForMaskedLM: bert.embeddings.word_embeddings
+        - RobertaForMaskedLM: roberta.embeddings.word_embeddings
+
+        argument of the XEmbedding layer for each model, but it is "blocked"
+        by a model-specific keyword (bert, )...
+        """
+        # self._tie_or_clone_weights(self.encoder, self.decoder)
+        raise NotImplementedError
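As a rough illustration of the problem the docstring describes, and not the committed implementation, one way to reach the word embeddings despite the inconsistent attribute paths is to probe them in the order listed above and fall back to the wrapped base model for the head classes. `base_model_prefix` and `_tie_or_clone_weights` are the standard `PreTrainedModel` attributes; everything else in this sketch is an assumption.

def _resolve_word_embeddings(model):
    """Best-effort lookup of a model's input embedding module (sketch only)."""
    # BertModel / RobertaModel: embeddings.word_embeddings
    if hasattr(model, "embeddings") and hasattr(model.embeddings, "word_embeddings"):
        return model.embeddings.word_embeddings
    # XLMModel: `embeddings` is itself the nn.Embedding
    if hasattr(model, "embeddings"):
        return model.embeddings
    # GPT2Model: wte
    if hasattr(model, "wte"):
        return model.wte
    # BertForMaskedLM, RobertaForMaskedLM, ...: recurse into the wrapped base model
    base = getattr(model, getattr(model, "base_model_prefix", ""), None)
    if base is not None and base is not model:
        return _resolve_word_embeddings(base)
    raise AttributeError("could not locate the word embeddings of %r" % type(model))

# tie_weights could then be along the lines of:
#     self._tie_or_clone_weights(_resolve_word_embeddings(self.decoder),
#                                _resolve_word_embeddings(self.encoder))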
 class Model2LSTM(PreTrainedSeq2seq):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        if kwargs.get('decoder_model', None) is None:
+        if kwargs.get("decoder_model", None) is None:
             # We will create a randomly initialized LSTM model as decoder
-            if 'decoder_config' not in kwargs:
-                raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
+            if "decoder_config" not in kwargs:
+                raise ValueError(
+                    "To load an LSTM in Seq2seq model, please supply either: "
                     " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
                     " - a dictionary of configuration parameters that will be used to initialize a"
                     " torch.nn.LSTM model as `decoder_config` keyword argument. "
-                    " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
-            kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config'))
+                    " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`"
+                )
+            kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
         model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
         return model
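To make the two code paths above concrete, a hedged usage sketch follows. The checkpoint name and LSTM sizes are placeholders, and `None` stands in for the decoder checkpoint because the decoder is supplied (or built) directly rather than loaded; as with the earlier sketches, this shows the intended call pattern rather than code guaranteed to run against this work-in-progress revision.

import torch

# Path 1 (sizes are illustrative): let Model2LSTM build the decoder from a config
# dict. Note that torch.nn.LSTM takes these values as keyword arguments, so the
# dict needs to be unpacked (**config) when the LSTM is constructed.
config = {"input_size": 768, "hidden_size": 768, "num_layers": 2}

# Path 2: hand over an already-constructed LSTM as `decoder_model`; the
# `decoder_config` branch above is then skipped entirely.
lstm_decoder = torch.nn.LSTM(**config)
model = Model2LSTM.from_pretrained("bert-base-uncased", None, decoder_model=lstm_decoder)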