"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d8e3bdbb4cce939e8f95e0f1fa33bdd7350f4b79"
Unverified commit cc840752, authored by Sylvain Gugger, committed by GitHub

Fix epoch number when resuming training (#21478)

parent 35f93f29
src/transformers/trainer.py
@@ -1798,8 +1798,10 @@ class Trainer:
                 self._load_rng_state(resume_from_checkpoint)
             rng_to_sync = False
+            steps_skipped = 0
             if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                 epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+                steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True
@@ -1907,7 +1909,7 @@ class Trainer:
                     model.zero_grad()
                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
+                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
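Why the change is needed: when training resumes from a mid-epoch checkpoint, skip_first_batches fast-forwards the dataloader, but the inner loop's step counter restarts at 0, so epoch + (step + 1) / steps_in_epoch jumps back toward the start of the epoch. Tracking the skipped steps and adding them back restores a monotonically increasing epoch number. A minimal standalone sketch (variable names mirror the diff above, but the loop and the numbers are illustrative, not the Trainer itself):

# Standalone illustration of the epoch-fraction fix; names mirror the diff
# above, but the loop and the numbers are made up for this sketch.
steps_in_epoch = 10
steps_trained_in_current_epoch = 4  # steps already completed before the checkpoint
epoch = 0

# After skip_first_batches, the dataloader yields batch 4 onward, but `step`
# restarts at 0 in the resumed loop.
steps_skipped = steps_trained_in_current_epoch
for step in range(steps_in_epoch - steps_skipped):
    old = epoch + (step + 1) / steps_in_epoch                   # 0.1, 0.2, ... (regresses)
    new = epoch + (step + 1 + steps_skipped) / steps_in_epoch   # 0.5, 0.6, ..., 1.0 (correct)
    print(f"step {step}: old epoch={old:.1f}, fixed epoch={new:.1f}")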
tests/trainer/test_trainer.py
@@ -1148,7 +1148,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # won't be the same since the training dataloader is shuffled).
         with tempfile.TemporaryDirectory() as tmpdir:
-            kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
+            kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, logging_steps=5)
             trainer = get_regression_trainer(**kwargs)
             trainer.train()
             (a, b) = trainer.model.a.item(), trainer.model.b.item()
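The added logging_steps=5 makes the regression trainer write self.state.epoch into the log history at every save step, which is what lets a resume test observe a wrong epoch fraction. A hedged sketch of how such a check could look, reusing get_regression_trainer and the kwargs from the hunk above (the checkpoint name and the exact comparison are assumptions, not the upstream test's assertions):

import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, logging_steps=5)

    # Full run: every logged entry carries an "epoch" value.
    trainer = get_regression_trainer(**kwargs)
    trainer.train()
    full_log = trainer.state.log_history

    # Resume from a mid-training checkpoint (save_steps=5 produces checkpoint-15)
    # and train to the end; "checkpoint-15" is an assumed example.
    trainer = get_regression_trainer(**kwargs)
    trainer.train(resume_from_checkpoint=os.path.join(tmpdir, "checkpoint-15"))
    resumed_log = trainer.state.log_history

    # With the fix, epochs logged after the resume point line up with the full
    # run; before it, they regressed to the start of the interrupted epoch.
    for full, resumed in zip(full_log[3:], resumed_log):
        assert abs(full["epoch"] - resumed["epoch"]) < 1e-6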