Unverified Commit 77384941 authored by Shai Erera's avatar Shai Erera Committed by GitHub
Browse files

Use model.from_pretrained for DataParallel also (#8795)

* Use model.from_pretrained for DataParallel also

When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end.

This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel.

If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden?

* Fix silly bug

* Address review comments

Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right?
parent 4062c75e
...@@ -808,8 +808,8 @@ class Trainer: ...@@ -808,8 +808,8 @@ class Trainer:
logger.info( logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
) )
if isinstance(model, PreTrainedModel): if isinstance(self.model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint) self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if not self.args.model_parallel: if not self.args.model_parallel:
self.model = self.model.to(self.args.device) self.model = self.model.to(self.args.device)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment