Unverified commit 7c10dd22 authored by Sylvain Gugger, committed by GitHub

Better support for resuming training (#8878)

parent 21db560d
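The diff below teaches Trainer.train to replay the data ordering of an interrupted run when it is pointed at a checkpoint folder, and adds an opt-out flag. A minimal usage sketch (not part of the diff; `model` and `train_dataset` are placeholders you would build yourself, and the checkpoint path is illustrative):

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="output",
        num_train_epochs=3,
        save_steps=5,
        # New in this commit: set to True to skip the data-replay step when
        # resuming (faster start, but the data order will differ from the
        # interrupted run).
        ignore_data_skip=False,
    )
    # `model` and `train_dataset` are placeholders for your own model and dataset.
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

    # In this version of the API the checkpoint folder is passed as `model_path`.
    trainer.train(model_path="output/checkpoint-15")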
@@ -665,12 +665,12 @@ class Trainer:
         )

         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", num_examples)
-        logger.info("  Num Epochs = %d", num_train_epochs)
-        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
-        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
-        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
-        logger.info("  Total optimization steps = %d", max_steps)
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Num Epochs = {num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
         epochs_trained = 0
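Aside on the logging change above: both forms emit the same text; the %-style call defers formatting to the logging module, while the f-string used from this commit on formats eagerly. A tiny self-contained comparison:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    num_examples = 64
    logger.info("  Num examples = %d", num_examples)   # lazy %-style formatting
    logger.info(f"  Num examples = {num_examples}")    # eager f-string formatting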
@@ -680,13 +680,20 @@ class Trainer:
         if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            if not self.args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d", epochs_trained)
-            logger.info("  Continuing training from global step %d", self.state.global_step)
-            logger.info("  Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not self.args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
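The resume bookkeeping above maps the saved `global_step` (a count of optimizer updates) back to whole epochs plus batches still to skip in the current epoch. A worked example with illustrative numbers (not taken from the diff):

    # Suppose one epoch contains 4 optimizer updates and gradient accumulation is 2,
    # i.e. each update consumes 2 batches.
    num_update_steps_per_epoch = 4
    gradient_accumulation_steps = 2

    global_step = 10  # updates already performed when the checkpoint was written

    epochs_trained = global_step // num_update_steps_per_epoch                 # 2 full epochs
    steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 2 updates into epoch 3
    steps_trained_in_current_epoch *= gradient_accumulation_steps              # -> 4 batches to skip

    print(epochs_trained, steps_trained_in_current_epoch)  # 2 4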
@@ -712,6 +719,13 @@ class Trainer:

         self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not self.args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                # We just need to begin an iteration to create the randomization of the sampler.
+                for _ in train_dataloader:
+                    break
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
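The skipping loop above works because a shuffling sampler draws its permutation from the RNG once per epoch, when iteration begins; replaying `epochs_trained` throwaway passes therefore leaves the RNG where an uninterrupted run would have left it. A standalone sketch of that property (assumes a single-process DataLoader drawing from the default torch RNG; this is not the Trainer code itself):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Uninterrupted run: record the shuffled order of three consecutive epochs.
    torch.manual_seed(0)
    full_run = [[x.item() for batch in loader for x in batch[0]] for _ in range(3)]

    # "Resumed" run: re-seed, consume two throwaway passes, then do epoch 3 for real.
    torch.manual_seed(0)
    for _ in range(2):
        for _ in loader:   # starting the iteration is enough to draw the permutation
            break
    epoch_3 = [x.item() for batch in loader for x in batch[0]]

    # Under the assumptions above, the replayed third epoch matches the original one.
    assert epoch_3 == full_run[2]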
@@ -189,6 +189,10 @@ class TrainingArguments:
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If there are more than one devices, whether to use model parallelism to distribute the model's modules
             across devices or not.
+        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
+            step can take a long time) but will not yield the same results as the interrupted training would have.
     """

     output_dir: str = field(
@@ -350,6 +354,12 @@ class TrainingArguments:
     greater_is_better: Optional[bool] = field(
         default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
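Because `TrainingArguments` is parsed by `HfArgumentParser` in the example scripts, the new field also surfaces as a `--ignore_data_skip` command-line flag. A hedged sketch (the exact boolean parsing depends on the HfArgumentParser version in use):

    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        ["--output_dir", "output", "--ignore_data_skip"]
    )
    print(training_args.ignore_data_skip)  # True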
@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionModel()
+            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            model.load_state_dict(state_dict)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
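Both new test blocks resume from `checkpoint-15`, a folder written by the Trainer's periodic saves; the resume path relies on that folder carrying the model weights plus the `trainer_state.json` holding the saved `global_step`. A hedged inspection sketch (folder name and contents are typical for this era of the library, not guaranteed):

    import json
    import os

    checkpoint = os.path.join("output", "checkpoint-15")

    # Typical contents: pytorch_model.bin, optimizer.pt, scheduler.pt,
    # trainer_state.json, training_args.bin (may vary by version and configuration).
    print(sorted(os.listdir(checkpoint)))

    with open(os.path.join(checkpoint, "trainer_state.json")) as f:
        state = json.load(f)
    print(state["global_step"])  # 15 -- the step the resume logic skips to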