Unverified Commit 3b8e2932 authored by Zach Mueller, committed by GitHub

Rework tests to compare trainer checkpoint args (#29883)



* Start rework

* Fix failing test

* Include max

* Update src/transformers/trainer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 6e584070
@@ -1565,23 +1565,28 @@ class Trainer:
             "logging_steps": "logging_steps",
             "eval_steps": "eval_steps",
             "save_steps": "save_steps",
-            "per_device_train_batch_size": "train_batch_size",
         }
-        warnings_list = []
+        has_warning = False
+        warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
         for arg_attr, state_attr in attributes_map.items():
             arg_value = getattr(training_args, arg_attr, None)
             state_value = getattr(trainer_state, state_attr, None)
             if arg_value is not None and state_value is not None and arg_value != state_value:
-                warnings_list.append(
-                    f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
-                    f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
-                )
+                warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
+                has_warning = True
+
+        # train bs is special as we need to account for multi-GPU
+        train_bs_args = training_args.per_device_train_batch_size
+        train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
+        if train_bs_args != train_bs_state:
+            warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
+            has_warning = True
 
-        if warnings_list:
-            for warning in warnings_list:
-                logger.warning(warning)
+        if has_warning:
+            logger.warning_once(warning_str)
 
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.use_ipex:
...
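For readers skimming the hunk above: the rework replaces the per-argument warnings with one consolidated message, and pulls per_device_train_batch_size out of the generic attribute map because trainer_state.json stores the total train batch size across devices, which must be divided back by the GPU count before comparing. The following is a minimal standalone sketch of that logic, using plain dataclasses as stand-ins for the real TrainingArguments/TrainerState objects; the names Args, State, and build_mismatch_warning are illustrative, not part of the transformers API.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Args:  # stand-in for TrainingArguments (illustrative values)
        logging_steps: int = 500
        eval_steps: int = 10
        save_steps: int = 10
        per_device_train_batch_size: int = 8
        n_gpu: int = 0


    @dataclass
    class State:  # stand-in for TrainerState loaded from trainer_state.json
        logging_steps: int = 500
        eval_steps: int = 5
        save_steps: int = 5
        train_batch_size: int = 4  # total batch size across devices


    def build_mismatch_warning(training_args: Args, trainer_state: State) -> Optional[str]:
        # Arguments that can be compared one-to-one between args and state.
        attributes_map = {
            "logging_steps": "logging_steps",
            "eval_steps": "eval_steps",
            "save_steps": "save_steps",
        }
        has_warning = False
        warning_str = (
            "Warning: The following arguments do not match the ones in the "
            "`trainer_state.json` within the checkpoint directory: "
        )
        for arg_attr, state_attr in attributes_map.items():
            arg_value = getattr(training_args, arg_attr, None)
            state_value = getattr(trainer_state, state_attr, None)
            if arg_value is not None and state_value is not None and arg_value != state_value:
                warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
                has_warning = True

        # Train batch size is special: divide the stored total by the
        # (at least 1) GPU count to recover the per-device value.
        train_bs_args = training_args.per_device_train_batch_size
        train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
        if train_bs_args != train_bs_state:
            warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
            has_warning = True

        return warning_str if has_warning else None


    if __name__ == "__main__":
        print(build_mismatch_warning(Args(), State()))

Run with the defaults above (which mirror the test values), this prints one warning listing the eval_steps, save_steps, and batch-size mismatches, much like what logger.warning_once emits in the patched Trainer.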
@@ -2540,16 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         )
         checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
 
-        self.assertIn(
-            "Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
-            cl.out,
-        )
-        self.assertIn(
-            "Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
-            cl.out,
-        )
-        self.assertIn(
-            "Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
-            cl.out,
-        )
+        self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
+        self.assertIn(
+            "per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
+            cl.out,
+        )
+        self.assertIn(
+            "eval_steps: 10 (from args) != 5 (from trainer_state.json)",
+            cl.out,
+        )
...
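For reference, with the mismatched values used in this test and assuming a single device (so the stored train_batch_size of 4 maps back to a per-device size of 4), the consolidated warning these assertIn calls match against would read roughly:

    Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: 
        eval_steps: 10 (from args) != 5 (from trainer_state.json)
        save_steps: 10 (from args) != 5 (from trainer_state.json)
        per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)

The exact line order follows the attribute map's iteration order, with the batch-size check appended last; the assertions only check substrings, so they do not depend on that order.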