Unverified Commit b17a5e00 authored by raghavanone, committed by GitHub

Fix issue #19300 (#19483)



* Fix issue #19300

* Fixing import order

* Fix issue #19300

* Fix formatting issues

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Refactor method

* Refactor method

* Fix the issue of sending wrong output dir

* Remove unused code
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d2ed8134
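
For context on the fix: the comment in the code added below indicates that, with save_total_limit=1, any end-of-training checkpoint that differs from the recorded best one should be deleted. Before this change, combining save_total_limit=1 with load_best_model_at_end=True could leave both the best and the most recent checkpoint on disk after training. A minimal sketch of a configuration that exercises this path (the output directory and step values are placeholders, not taken from the commit):

from transformers import TrainingArguments

# With this configuration, checkpoint rotation keeps the best checkpoint plus
# the most recent one while training runs; the cleanup added in this commit
# removes the non-best one once training finishes, so save_total_limit=1 is
# honored at the end.
args = TrainingArguments(
    output_dir="out",             # placeholder directory
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=1,           # keep at most one checkpoint
    load_best_model_at_end=True,  # also track the best checkpoint
)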
@@ -1876,10 +1876,40 @@ class Trainer:
         self.log(metrics)
 
+        run_dir = self._get_output_dir(trial)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint.
+        if self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+            for checkpoint in checkpoints_sorted:
+                if checkpoint != self.state.best_model_checkpoint:
+                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                    shutil.rmtree(checkpoint)
+
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
+    def _get_output_dir(self, trial):
+        if self.hp_search_backend is not None and trial is not None:
+            if self.hp_search_backend == HPSearchBackend.OPTUNA:
+                run_id = trial.number
+            elif self.hp_search_backend == HPSearchBackend.RAY:
+                from ray import tune
+
+                run_id = tune.get_trial_id()
+            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+                run_id = trial.id
+            elif self.hp_search_backend == HPSearchBackend.WANDB:
+                import wandb
+
+                run_id = wandb.run.id
+            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+            run_dir = os.path.join(self.args.output_dir, run_name)
+        else:
+            run_dir = self.args.output_dir
+        return run_dir
+
     def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         if model is None:
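
The new _get_output_dir helper centralizes the run-directory resolution that the training loop and _save_checkpoint previously duplicated, while the block above it performs the actual cleanup. A standalone sketch of that cleanup logic outside the Trainer (the run directory, checkpoint names, and best-checkpoint choice are invented for illustration):

import os
import shutil

# Mimics the loop added above: with save_total_limit=1, every checkpoint
# except the recorded best one is deleted once training finishes.
run_dir = "out"
best_model_checkpoint = os.path.join(run_dir, "checkpoint-500")
checkpoints_sorted = [
    os.path.join(run_dir, "checkpoint-500"),
    os.path.join(run_dir, "checkpoint-1000"),  # most recent, but not the best
]

for checkpoint in checkpoints_sorted:
    if checkpoint != best_model_checkpoint:
        shutil.rmtree(checkpoint, ignore_errors=True)  # ignore_errors only for this sketch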
@@ -2105,25 +2135,10 @@ class Trainer:
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
-        if self.hp_search_backend is not None and trial is not None:
-            if self.hp_search_backend == HPSearchBackend.OPTUNA:
-                run_id = trial.number
-            elif self.hp_search_backend == HPSearchBackend.RAY:
-                from ray import tune
-
-                run_id = tune.get_trial_id()
-            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
-                run_id = trial.id
-            elif self.hp_search_backend == HPSearchBackend.WANDB:
-                import wandb
-
-                run_id = wandb.run.id
-            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
-            run_dir = os.path.join(self.args.output_dir, run_name)
-        else:
-            run_dir = self.args.output_dir
-            self.store_flos()
+        if self.hp_search_backend is None and trial is None:
+            self.store_flos()
 
+        run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir, _internal_call=True)
         if self.deepspeed:
...
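
After the refactor, _save_checkpoint builds the checkpoint path from the shared helper instead of recomputing the run directory inline, and it now calls store_flos only when no hyperparameter search is running. A small sketch of the resulting path composition (the step count and run directory are invented; PREFIX_CHECKPOINT_DIR is the "checkpoint" constant from transformers.trainer_utils):

import os

PREFIX_CHECKPOINT_DIR = "checkpoint"  # value of the constant used in trainer.py
global_step = 1500                    # invented step count
run_dir = "out/run-7"                 # what _get_output_dir would return for an HP-search trial

checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)
print(output_dir)  # out/run-7/checkpoint-1500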