Unverified Commit 35d55b7b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

When resuming training from checkpoint, Trainer loads model (#9818)

* When resuming training from checkpoint, Trainer loads model

* Finish cleaning tests

* Address review comment

* Use global_step from state
parent 6b6c2b48
...@@ -688,20 +688,31 @@ class Trainer: ...@@ -688,20 +688,31 @@ class Trainer:
self._hp_search_setup(trial) self._hp_search_setup(trial)
# Model re-init # Model re-init
model_reloaded = False
if self.model_init is not None: if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init. # Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed) set_seed(self.args.seed)
self.model = self.call_model_init(trial)
model = self.call_model_init(trial) model_reloaded = True
if not self.is_model_parallel:
model = model.to(self.args.device)
self.model = model
self.model_wrapped = model
# Reinitializes optimizer and scheduler # Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None self.optimizer, self.lr_scheduler = None, None
# Load potential model checkpoint
if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
logger.info(f"Loading model from {model_path}).")
if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(model_path)
model_reloaded = True
else:
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if not self.is_model_parallel:
self.model = self.model.to(self.args.device)
self.model_wrapped = self.model
# Keeping track whether we can can len() on the dataset or not # Keeping track whether we can can len() on the dataset or not
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
...@@ -849,7 +860,7 @@ class Trainer: ...@@ -849,7 +860,7 @@ class Trainer:
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0 self._total_loss_scalar = 0.0
self._globalstep_last_logged = 0 self._globalstep_last_logged = self.state.global_step
self._total_flos = self.state.total_flos self._total_flos = self.state.total_flos
model.zero_grad() model.zero_grad()
......
...@@ -578,9 +578,8 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -578,9 +578,8 @@ class TrainerIntegrationTest(unittest.TestCase):
checkpoint = os.path.join(tmpdir, "checkpoint-5") checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model # Reinitialize trainer
model = RegressionPreTrainedModel.from_pretrained(checkpoint) trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint) trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -593,8 +592,7 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -593,8 +592,7 @@ class TrainerIntegrationTest(unittest.TestCase):
checkpoint = os.path.join(tmpdir, "checkpoint-15") checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model # Reinitialize trainer and load model
model = RegressionPreTrainedModel.from_pretrained(checkpoint) trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint) trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -615,10 +613,9 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -615,10 +613,9 @@ class TrainerIntegrationTest(unittest.TestCase):
checkpoint = os.path.join(tmpdir, "checkpoint-5") checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model # Reinitialize trainer and load model
model = RegressionModel() trainer = get_regression_trainer(
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
model.load_state_dict(state_dict) )
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint) trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -631,10 +628,9 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -631,10 +628,9 @@ class TrainerIntegrationTest(unittest.TestCase):
checkpoint = os.path.join(tmpdir, "checkpoint-15") checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model # Reinitialize trainer and load model
model = RegressionModel() trainer = get_regression_trainer(
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
model.load_state_dict(state_dict) )
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint) trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -664,9 +660,15 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -664,9 +660,15 @@ class TrainerIntegrationTest(unittest.TestCase):
checkpoint = os.path.join(tmpdir, "checkpoint-5") checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model # Reinitialize trainer
model = RegressionPreTrainedModel.from_pretrained(checkpoint) trainer = get_regression_trainer(
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) output_dir=tmpdir,
train_len=128,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.train(model_path=checkpoint) trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment