"vscode:/vscode.git/clone" did not exist on "ef176d29dffe5fb9f962d5d8c2aa8dcfb7ba2464"
Unverified commit b5a6d6ee authored by Jonathan Flynn, committed by GitHub

Add warnings if training args differ from checkpoint trainer state (#29255)



* add warnings if training args differ from checkpoint args stored in trainer_state.json

* run formatting and styling

* add a test

* format and styling

---------
Co-authored-by: Jonathan Flynn <jonl.flynn@guardian.co.uk>
parent 7eb3ba82
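For context, this is what the change looks like from the user's side: resuming from a checkpoint whose recorded arguments differ from the current `TrainingArguments` now logs a warning for each mismatching field before training continues. A minimal sketch, not part of this commit; the model, dataset, paths, and values are assumed placeholders:

```python
# Sketch only: `model` and `train_dataset` stand in for any existing model/dataset,
# and output/checkpoint-5 is assumed to have been saved with save_steps=5.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="output", save_steps=10)  # differs from the checkpoint's 5
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# On resume, trainer_state.json is loaded from the checkpoint and each mismatching
# argument is reported via logger.warning(); the checkpoint's values take precedence.
trainer.train(resume_from_checkpoint="output/checkpoint-5")
# Warning: The training argument 'save_steps' value (10) does not match the trainer
# state 'save_steps' value (5). This argument will be overridden by the one found in
# trainer_state.json within the checkpoint directory.
```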
...
@@ -1529,6 +1529,29 @@ class Trainer:
return model
def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
attributes_map = {
"logging_steps": "logging_steps",
"eval_steps": "eval_steps",
"save_steps": "save_steps",
"per_device_train_batch_size": "train_batch_size",
}
warnings_list = []
for arg_attr, state_attr in attributes_map.items():
arg_value = getattr(training_args, arg_attr, None)
state_value = getattr(trainer_state, state_attr, None)
if arg_value is not None and state_value is not None and arg_value != state_value:
warnings_list.append(
f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
)
if warnings_list:
for warning in warnings_list:
logger.warning(warning)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
...
@@ -1991,6 +2014,7 @@ class Trainer:
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
...
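The `attributes_map` in the new method pairs each `TrainingArguments` field with the field name recorded in the checkpoint's `trainer_state.json` (note that `per_device_train_batch_size` maps to `train_batch_size`). Below is a self-contained sketch of that comparison with the state built in memory rather than loaded from a checkpoint; all values are assumed for illustration and it is not part of the diff:

```python
from transformers import TrainerState, TrainingArguments

# Values a checkpoint's trainer_state.json might have recorded (assumed).
state = TrainerState(logging_steps=500, eval_steps=5, save_steps=5, train_batch_size=4)
# Current run's arguments, intentionally different from the checkpoint.
args = TrainingArguments(output_dir="out", eval_steps=10, save_steps=10, per_device_train_batch_size=8)

# Same pairing the new method uses: TrainingArguments name -> TrainerState name.
attributes_map = {
    "logging_steps": "logging_steps",
    "eval_steps": "eval_steps",
    "save_steps": "save_steps",
    "per_device_train_batch_size": "train_batch_size",
}
for arg_attr, state_attr in attributes_map.items():
    arg_value = getattr(args, arg_attr, None)
    state_value = getattr(state, state_attr, None)
    if arg_value is not None and state_value is not None and arg_value != state_value:
        print(f"'{arg_attr}' ({arg_value}) does not match checkpoint '{state_attr}' ({state_value})")
```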
...
@@ -2485,6 +2485,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
def test_compare_trainer_and_checkpoint_args_logging(self):
logger = logging.get_logger()
with tempfile.TemporaryDirectory() as tmpdir, CaptureLogger(logger) as cl:
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
eval_steps=5,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.train()
checkpoint = os.path.join(tmpdir, "checkpoint-5")
checkpoint_trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=256,
eval_steps=10,
gradient_accumulation_steps=4,
per_device_train_batch_size=8,
save_steps=10,
learning_rate=0.1,
)
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
self.assertIn(
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
def check_mem_metrics(self, trainer, check_func):
metrics = trainer.train().metrics
check_func("init_mem_cpu_alloc_delta", metrics)
...
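The new test relies on the library's own helpers: `logging.get_logger()` resolves to the transformers logging utilities, and `CaptureLogger` comes from `transformers.testing_utils`. A minimal, self-contained sketch of the capture pattern (the message text is illustrative only):

```python
from transformers.testing_utils import CaptureLogger
from transformers.utils import logging

logger = logging.get_logger()  # root "transformers" logger; child module loggers propagate to it

with CaptureLogger(logger) as cl:
    logger.warning("save_steps mismatch")  # anything logged inside the block is captured

assert "save_steps mismatch" in cl.out  # cl.out holds the captured text once the block exits
```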