"...zh_cn/git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "f256abffd38e848211c208ac706a2e35b1ac2e94"
Unverified Commit b7439675 authored by Philip May's avatar Philip May Committed by GitHub
Browse files

Fix `Trainer.train(resume_from_checkpoint=False)` causing an exception (#12981)



* fix #12970

* Update tests/test_trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_trainer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove unnecessary issue link

* fix test formatting
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 790f1c95
...@@ -1005,6 +1005,7 @@ class Trainer: ...@@ -1005,6 +1005,7 @@ class Trainer:
kwargs: kwargs:
Additional keyword arguments used to hide deprecated arguments Additional keyword arguments used to hide deprecated arguments
""" """
resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
# memory metrics - must set up as early as possible # memory metrics - must set up as early as possible
self._memory_tracker.start() self._memory_tracker.start()
......
...@@ -827,6 +827,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -827,6 +827,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertAlmostEqual(a, a1, delta=1e-8) self.assertAlmostEqual(a, a1, delta=1e-8)
self.assertAlmostEqual(b, b1, delta=1e-8) self.assertAlmostEqual(b, b1, delta=1e-8)
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_flase(self):
    """Training with ``resume_from_checkpoint=False`` must run without raising.

    Before the fix, a literal ``False`` was forwarded as a checkpoint path and
    crashed ``Trainer.train``; here we simply assert the call completes.
    """
    config = RegressionModelConfig(a=0, b=2)
    model = RegressionRandomPreTrainedModel(config)
    training_set = RegressionDataset(length=128)
    validation_set = RegressionDataset()
    output_dir = self.get_auto_remove_tmp_dir()
    training_args = RegressionTrainingArguments(output_dir, save_steps=5, learning_rate=0.1)
    trainer = Trainer(model, training_args, train_dataset=training_set, eval_dataset=validation_set)
    # Must not raise — False means "do not resume", not a checkpoint path.
    trainer.train(resume_from_checkpoint=False)
@require_torch_up_to_2_gpus @require_torch_up_to_2_gpus
def test_resume_training_with_gradient_accumulation(self): def test_resume_training_with_gradient_accumulation(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment