Unverified commit daf281f4, authored by Zach Mueller, committed by GitHub

Enforce saving at end of training if saving option chosen (#30160)

* Enforce saving at end of training

* Fix test

* Rework test

* Fixup tests

* Update comment based on Sourab's feedback

* Clean
Parent: 7a4792e6
@@ -544,6 +544,9 @@ class DefaultFlowCallback(TrainerCallback):
         # End training
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
+            # Save the model at the end if we have a save strategy
+            if args.save_strategy != IntervalStrategy.NO:
+                control.should_save = True
 
         return control
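
For context, here is a minimal sketch of how the newly enforced final save can be observed from a user-side callback. The `SaveLogger` class and the step values are illustrative, not part of this change; the only assumption is the existing `TrainerCallback.on_save` event, which `DefaultFlowCallback` now also triggers on the last training step whenever `save_strategy` is not `"no"`.

```python
from transformers import TrainerCallback


class SaveLogger(TrainerCallback):
    """Illustrative callback that records the global step of every checkpoint save."""

    def __init__(self):
        self.saved_steps = []

    def on_save(self, args, state, control, **kwargs):
        # Called each time the Trainer writes a checkpoint.
        self.saved_steps.append(state.global_step)


# Hypothetical usage: Trainer(..., callbacks=[SaveLogger()]) with save_steps=5 and
# max_steps=9 is expected to end up with saved_steps == [5, 9], since a save is
# now also forced at the end of training.
```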
...
@@ -335,6 +335,9 @@ class TrainingArguments:
                 - `"no"`: No save is done during training.
                 - `"epoch"`: Save is done at the end of each epoch.
                 - `"steps"`: Save is done every `save_steps`.
+
+                If `"epoch"` or `"steps"` is chosen, saving will always be performed at the
+                very end of training.
         save_steps (`int` or `float`, *optional*, defaults to 500):
             Number of update steps between two checkpoint saves if `save_strategy="steps"`. Should be an integer
             or a float in range `[0,1)`. If smaller than 1, it will be interpreted as a ratio of total training steps.
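
As a usage sketch of the documented behavior (all values below are hypothetical): when `save_steps` does not evenly divide the number of training steps, a checkpoint is still written at the very last step.

```python
from transformers import TrainingArguments

# Hypothetical configuration: checkpoints are written at step 5 and, with this
# change, also at step 9 (the final step), because save_strategy is not "no".
args = TrainingArguments(
    output_dir="out",
    save_strategy="steps",
    save_steps=5,
    max_steps=9,
)
```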
...
@@ -129,6 +129,7 @@ if is_torch_available():
     if is_safetensors_available():
         import safetensors.torch
 
+# for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
@@ -2016,6 +2017,56 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
             )
 
+    def test_load_best_model_with_save(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                save_steps=5,
+                evaluation_strategy="steps",
+                eval_steps=5,
+                max_steps=9,
+            )
+            trainer.train()
+            # Check that we have the last known step:
+            assert os.path.exists(
+                os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
+            ), f"Could not find checkpoint-{trainer.state.max_steps}"
+            # And then check the last step
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"
+
+        # Now test that using a limit works
+        # Should result in:
+        # - save at step 5 (but is deleted)
+        # - save at step 10 (loaded in at the end when `load_best_model_at_end=True`)
+        # - save at step 11
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                save_steps=5,
+                evaluation_strategy="steps",
+                eval_steps=5,
+                load_best_model_at_end=True,
+                save_total_limit=2,
+                max_steps=11,
+            )
+            trainer.train()
+            # Check that we have the last known step:
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
+            # And then check the last multiple
+            assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
+            # Check that we don't have an old one
+            assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"
+
+            # Finally check that the right model was loaded in: checkpoint-10.
+            # This goes by the last `eval` step, so it won't be the last model *saved*.
+            model_state = trainer.model.state_dict()
+            final_model_weights = safetensors.torch.load_file(
+                os.path.join(tmpdir, "checkpoint-10", "model.safetensors")
+            )
+            for k, v in model_state.items():
+                assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"
+
     @require_torch_multi_accelerator
     def test_run_seq2seq_double_train_wrap_once(self):
         # test that we don't wrap the model more than once
...
@@ -153,7 +153,7 @@ class TrainerCallbackTest(unittest.TestCase):
                     expected_events.append("on_log")
                 if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
                     expected_events += evaluation_events.copy()
-                if step % trainer.args.save_steps == 0:
+                if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
                     expected_events.append("on_save")
             expected_events.append("on_epoch_end")
             if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
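
A quick illustration of the updated expectation in this test (numbers are hypothetical): `on_save` is now expected both at multiples of `save_steps` and at the final training step.

```python
# Hypothetical values: save_steps=4, max_steps=10.
save_steps, max_steps = 4, 10
steps_with_save = [
    step for step in range(1, max_steps + 1) if step % save_steps == 0 or step == max_steps
]
assert steps_with_save == [4, 8, 10]
```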