Unverified Commit 6757ed28 authored by Zach Mueller, committed by GitHub

Allow `resume_from_checkpoint` to handle `auto_find_batch_size` (#27568)



* Fulfill request

* Add test

* Better test

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Better test

* Better test

* More comments

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
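
In short, the `Trainer` now records the batch size that `auto_find_batch_size` settles on in its `TrainerState`, saves it with every checkpoint, and restores it when training resumes, so a resumed run does not retry the original (possibly OOM-ing) batch size. A minimal usage sketch of the behaviour this enables, assuming a user-defined `model` and `train_dataset` (placeholders, not part of this diff):

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,  # starting batch size before any backoff
    auto_find_batch_size=True,       # halve the batch size whenever CUDA OOMs
    save_steps=1,
    max_steps=2,
)

# First run: `auto_find_batch_size` may settle on a smaller batch size (e.g. 8)
# after an OOM; that value is now written into each checkpoint's trainer_state.json.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# Resumed run: a fresh Trainer picks the stored batch size back up from the
# checkpoint instead of restarting at per_device_train_batch_size=16.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train(resume_from_checkpoint=True)
```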
parent aa7ab98e
@@ -1507,6 +1507,10 @@ class Trainer:
            and not self.is_fsdp_enabled
        ):
            self._load_from_checkpoint(resume_from_checkpoint)
            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            if state.train_batch_size is not None:
                self._train_batch_size = state.train_batch_size

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
@@ -1542,6 +1546,8 @@ class Trainer:
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@ class Trainer:
        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
......
@@ -59,6 +59,9 @@ class TrainerState:
            Run an evaluation every X steps.
        save_steps (`int`, *optional*, defaults to 500):
            Save checkpoint every X updates steps.
        train_batch_size (`int`, *optional*):
            The batch size for the training dataloader. Only needed when
            `auto_find_batch_size` has been used.
        num_input_tokens_seen (`int`, *optional*, defaults to 0):
            The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
    logging_steps: int = 500
    eval_steps: int = 500
    save_steps: int = 500
    train_batch_size: int = None
    num_train_epochs: int = 0
    num_input_tokens_seen: int = 0
    total_flos: float = 0
......
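
Because `train_batch_size` is a plain field on the `TrainerState` dataclass, it round-trips through the `trainer_state.json` file the `Trainer` writes into every checkpoint. A standalone sketch of that round trip (not part of the diff; the temporary directory and the value 8 are illustrative):

```python
import os
import tempfile

from transformers import TrainerState

with tempfile.TemporaryDirectory() as checkpoint_dir:
    path = os.path.join(checkpoint_dir, "trainer_state.json")

    # Pretend `auto_find_batch_size` settled on a batch size of 8 and a checkpoint was saved.
    TrainerState(train_batch_size=8).save_to_json(path)

    # On resume, the Trainer reloads the state and adopts the stored batch size.
    state = TrainerState.load_from_json(path)
    assert state.train_batch_size == 8
```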
@@ -38,6 +38,7 @@ from transformers import (
    AutoTokenizer,
    IntervalStrategy,
    PretrainedConfig,
    TrainerCallback,
    TrainingArguments,
    get_polynomial_decay_schedule_with_warmup,
    is_torch_available,
@@ -1546,6 +1547,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        with patch.object(sys, "argv", testargs):
            run_glue.main()

    def test_auto_batch_size_with_resume_from_checkpoint(self):
        train_dataset = RegressionDataset(length=128)

        config = RegressionModelConfig(a=0, b=2)
        model = RegressionRandomPreTrainedModel(config)

        tmp_dir = self.get_auto_remove_tmp_dir()

        class MockCudaOOMCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                # simulate OOM on the first step
                if state.train_batch_size == 16:
                    raise RuntimeError("CUDA out of memory.")

        args = RegressionTrainingArguments(
            tmp_dir,
            do_train=True,
            max_steps=2,
            save_steps=1,
            per_device_train_batch_size=16,
            auto_find_batch_size=True,
        )
        trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
        trainer.train()
        # After `auto_find_batch_size` has run, we should now be at 8
        self.assertEqual(trainer._train_batch_size, 8)

        # We can then make a new Trainer
        trainer = Trainer(model, args, train_dataset=train_dataset)
        # Check we are at 16 to start
        self.assertEqual(trainer._train_batch_size, 16)
        trainer.train(resume_from_checkpoint=True)
        # We should be back to 8 again, picking up from the last Trainer run
        self.assertEqual(trainer._train_batch_size, 8)

    # regression for this issue: https://github.com/huggingface/transformers/issues/12970
    def test_training_with_resume_from_checkpoint_false(self):
        train_dataset = RegressionDataset(length=128)
......