Unverified Commit 3318c246 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

make failure to find a resume checkpoint fatal + tests (#10777)

parent cd8c93f7
...@@ -876,7 +876,10 @@ class Trainer: ...@@ -876,7 +876,10 @@ class Trainer:
if resume_from_checkpoint is None: if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): if resume_from_checkpoint is not None:
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
logger.info(f"Loading model from {resume_from_checkpoint}).") logger.info(f"Loading model from {resume_from_checkpoint}).")
if self.deepspeed: if self.deepspeed:
......
...@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
return return
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = get_regression_trainer(**kwargs)
trainer.train() trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item() (a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state) state = dataclasses.asdict(trainer.state)
...@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-5") checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer # Reinitialize trainer
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint) trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-15") checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model # Reinitialize trainer and load model
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint) trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# With a regular model that is not a PreTrainedModel # With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer( kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False)
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
) trainer = get_regression_trainer(**kwargs)
trainer.train() trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item() (a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state) state = dataclasses.asdict(trainer.state)
...@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-5") checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model # Reinitialize trainer and load model
trainer = get_regression_trainer( trainer = get_regression_trainer(**kwargs)
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer.train(resume_from_checkpoint=checkpoint) trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-15") checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model # Reinitialize trainer and load model
trainer = get_regression_trainer( trainer = get_regression_trainer(**kwargs)
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer.train(resume_from_checkpoint=checkpoint) trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item() (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
...@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1) self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1) self.check_trainer_state_are_the_same(state, state1)
# Now check failures
# 1. fail to find a bogus checkpoint
trainer = get_regression_trainer()
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
self.assertTrue("Can't find a valid checkpoint at" in str(context.exception))
    # 2. fail to find any checkpoint - due to a fresh output_dir
output_dir2 = self.get_auto_remove_tmp_dir()
trainer = get_regression_trainer(output_dir=output_dir2)
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
def test_resume_training_with_gradient_accumulation(self): def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2: if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment