Commit 1c71ecc8 authored by Rémi Louf, committed by Julien Chaumond

load the pretrained weights for encoder-decoder

We currently save the pretrained weights of the encoder and decoder in
two separate directories, `encoder` and `decoder`. However, for the
`from_pretrained` function to work with automodels, we need to specify
the type of model in the path to the weights.

The path to the encoder/decoder weights is handled by the
`PreTrainedEncoderDecoder` class in the `save_pretrained` function. Since
there is no easy way to infer the type of model that was initialized for
the encoder and decoder, we add a `model_type` parameter to the function.
This is not an ideal solution, as it is error-prone; the model type
should somehow be carried by the model classes themselves.

This is a temporary fix that should be changed before merging.
parent 07f4cd73
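For context on why the directory name matters: the automodel-style loader infers the architecture from the name of the weights directory it is pointed at, so a directory called `encoder` gives it nothing to match on, while `bert_encoder` does. The snippet below is a simplified, hypothetical sketch of that kind of path-based dispatch, not the actual transformers implementation.

    import os

    # Hypothetical sketch of path-based dispatch: the loader scans the directory
    # name for a known model identifier and picks the architecture accordingly.
    KNOWN_MODEL_TYPES = ["bert", "gpt2", "roberta"]

    def infer_model_type_from_path(path):
        directory_name = os.path.basename(os.path.normpath(path))
        for model_type in KNOWN_MODEL_TYPES:
            if model_type in directory_name:
                return model_type
        raise ValueError("Cannot infer a model type from '{}'".format(directory_name))

    # infer_model_type_from_path("checkpoints/bert_encoder") -> "bert"
    # infer_model_type_from_path("checkpoints/encoder")      -> ValueError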
@@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""):
     return result
 
 
+def save_model_checkpoints(args, model, tokenizer):
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    logger.info("Saving model checkpoint to %s", args.output_dir)
+
+    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+    # They can then be reloaded using `from_pretrained()`
+    model_to_save = (
+        model.module if hasattr(model, "module") else model
+    )  # Take care of distributed/parallel training
+    model_to_save.save_pretrained(args.output_dir, model_type='bert')
+    tokenizer.save_pretrained(args.output_dir)
+    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+
+
 def main():
     parser = argparse.ArgumentParser()

@@ -454,36 +470,30 @@ def main():
     # Train the model
     model.to(args.device)
     if args.do_train:
-        global_step, tr_loss = train(args, model, tokenizer)
-        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
-
-    logger.info("Saving model checkpoint to %s", args.output_dir)
-
-    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
-    # They can then be reloaded using `from_pretrained()`
-    model_to_save = (
-        model.module if hasattr(model, "module") else model
-    )  # Take care of distributed/parallel training
-    model_to_save.save_pretrained(args.output_dir)
-    tokenizer.save_pretrained(args.output_dir)
-    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+        try:
+            global_step, tr_loss = train(args, model, tokenizer)
+        except KeyboardInterrupt:
+            response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
+            if response.lower() in ["", "y", "yes"]:
+                save_model_checkpoints(args, model, tokenizer)
+            sys.exit(0)
+
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+        save_model_checkpoints(args, model, tokenizer)
 
     # Evaluate the model
     results = {}
     if args.do_evaluate:
-        checkpoints = []
+        checkpoints = [args.output_dir]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
-            encoder_checkpoint = os.path.join(checkpoint, "encoder")
-            decoder_checkpoint = os.path.join(checkpoint, "decoder")
+            encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
+            decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
             model = PreTrainedEncoderDecoder.from_pretrained(
                 encoder_checkpoint, decoder_checkpoint
             )
             model.to(args.device)
-            results = "placeholder"
+            print("model loaded")
 
     return results
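Putting the example script and the `save_pretrained` change together, the intended round trip looks roughly like the sketch below. It only uses the calls that appear in this diff; the output directory and the `bert-base-uncased` checkpoints are placeholders for the example.

    import os
    from transformers import PreTrainedEncoderDecoder

    output_dir = "./checkpoints"  # placeholder output directory for the example

    # Initialize an encoder-decoder from two pretrained BERT checkpoints
    # (one identifier for the encoder, one for the decoder).
    model = PreTrainedEncoderDecoder.from_pretrained("bert-base-uncased", "bert-base-uncased")

    # Saving: weights land in <output_dir>/bert_encoder and <output_dir>/bert_decoder,
    # so the loader can read the model type off the directory names.
    model.save_pretrained(output_dir, model_type="bert")

    # Reloading: point from_pretrained at the two type-prefixed directories.
    reloaded = PreTrainedEncoderDecoder.from_pretrained(
        os.path.join(output_dir, "bert_encoder"),
        os.path.join(output_dir, "bert_decoder"),
    )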
@@ -117,8 +117,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()

@@ -158,14 +157,27 @@ class PreTrainedEncoderDecoder(nn.Module):
         return model
 
-    def save_pretrained(self, save_directory):
-        """ Save a Seq2Seq model and its configuration file in a format such
+    def save_pretrained(self, save_directory, model_type="bert"):
+        """ Save an EncoderDecoder model and its configuration file in a format such
         that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
 
         We save the encoder's and decoder's parameters in two separate directories.
+        If we want the weight loader to function we need to prepend the model
+        type to the directories' names. As far as I know there is no simple way
+        to infer the type of the model (except maybe by parsing the class
+        names, which is not very future-proof). For now, we ask the user to
+        specify the model type explicitly when saving the weights.
         """
-        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
-        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
+        encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
+        if not os.path.exists(encoder_path):
+            os.makedirs(encoder_path)
+        self.encoder.save_pretrained(encoder_path)
+
+        decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
+        if not os.path.exists(decoder_path):
+            os.makedirs(decoder_path)
+        self.decoder.save_pretrained(decoder_path)
 
     def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2seq depends on what we are performing:

@@ -193,8 +205,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()

@@ -217,9 +228,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[
-                0
-            ]  # output the last layer hidden state
+            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
         else:
             encoder_outputs = ()
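The docstring above mentions, and rejects, inferring the model type by parsing the class names of the encoder and decoder. A hypothetical version of that alternative is sketched below; it only works while class names follow the `BertModel` / `BertForMaskedLM` naming convention, which is why it is described as not future-proof.

    def infer_model_type(model):
        # Hypothetical sketch: derive "bert" from a class name such as
        # "BertModel" or "BertForMaskedLM" by stripping a known suffix.
        class_name = model.__class__.__name__
        for suffix in ("ForMaskedLM", "ForSequenceClassification", "Model"):
            if class_name.endswith(suffix):
                return class_name[: -len(suffix)].lower()
        raise ValueError("Cannot infer a model type from class '{}'".format(class_name))

    # infer_model_type(BertModel(config))        -> "bert"
    # infer_model_type(SomeFutureArchitecture()) -> ValueError, hence not future-proof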