"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "963db81a5afc5c31edafa0c9fb0390956a78cf2a"
Unverified commit d14af22c authored by Wang, Yi, committed by GitHub

add DDP HPO support for optuna (#19002)



Only the main process runs the Optuna HPO loop; it passes the sampled training arguments to the other processes (a usage sketch follows below).
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 00fc9217
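For context, a minimal usage sketch (not part of this commit): every rank runs the same script and calls trainer.hyperparameter_search; with this change, rank 0 drives the Optuna study while the other ranks loop over the trials and receive the sampled TrainingArguments. The toy dataset, model checkpoint name, output directory, and the torchrun launch command mentioned in the comments are illustrative assumptions, not taken from the commit.

# Sketch only: run with e.g. `torchrun --nproc_per_node=2 hpo_sketch.py` (assumed launcher).
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class ToyDataset(Dataset):
    """Tiny random dataset so the sketch is self-contained (placeholder data)."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {
            "input_ids": torch.randint(0, 1000, (16,)),
            "attention_mask": torch.ones(16, dtype=torch.long),
            "labels": torch.tensor(idx % 2),
        }


def model_init():
    # A fresh model is built for every trial.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")


def hp_space(trial):
    # Search space; only the main process samples from Optuna.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [8, 16]
        ),
    }


trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hpo_out", num_train_epochs=1),
    train_dataset=ToyDataset(),
    eval_dataset=ToyDataset(),
)

# Every rank calls hyperparameter_search. With this patch, rank 0 runs the Optuna
# study and broadcasts the sampled TrainingArguments; the other ranks loop over
# n_trials, receive the arguments, and just train. Only rank 0 gets a BestRun back.
best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="optuna",
    hp_space=hp_space,
    n_trials=4,
)
if trainer.args.process_index == 0:
    print(best_run)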
@@ -159,26 +159,49 @@ def default_hp_search_backend():
 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import optuna
 
-    def _objective(trial, checkpoint_dir=None):
-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
-        trainer.objective = None
-        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
-        # If there hasn't been any evaluation during the training loop.
-        if getattr(trainer, "objective", None) is None:
-            metrics = trainer.evaluate()
-            trainer.objective = trainer.compute_objective(metrics)
-        return trainer.objective
-
-    timeout = kwargs.pop("timeout", None)
-    n_jobs = kwargs.pop("n_jobs", 1)
-    study = optuna.create_study(direction=direction, **kwargs)
-    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
-    best_trial = study.best_trial
-    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+    if trainer.args.process_index == 0:
+
+        def _objective(trial, checkpoint_dir=None):
+            checkpoint = None
+            if checkpoint_dir:
+                for subdir in os.listdir(checkpoint_dir):
+                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
+                        checkpoint = os.path.join(checkpoint_dir, subdir)
+            trainer.objective = None
+            trainer._hp_search_setup(trial)
+            if trainer.args.world_size > 1:
+                if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
+                    raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
+                torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
+            trainer.train(resume_from_checkpoint=checkpoint)
+            # If there hasn't been any evaluation during the training loop.
+            if getattr(trainer, "objective", None) is None:
+                metrics = trainer.evaluate()
+                trainer.objective = trainer.compute_objective(metrics)
+            return trainer.objective
+
+        timeout = kwargs.pop("timeout", None)
+        n_jobs = kwargs.pop("n_jobs", 1)
+        study = optuna.create_study(direction=direction, **kwargs)
+        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
+        best_trial = study.best_trial
+        return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+    else:
+        for i in range(n_trials):
+            trainer.objective = None
+            args_main_rank = list(pickle.dumps(trainer.args))
+            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
+            torch.distributed.broadcast_object_list(args_main_rank, src=0)
+            local_rank = trainer.args.local_rank  # backup the local_rank info
+            trainer.args = pickle.loads(bytes(args_main_rank))
+            trainer.args.local_rank = local_rank
+            trainer.train(resume_from_checkpoint=None)
+            # If there hasn't been any evaluation during the training loop.
+            if getattr(trainer, "objective", None) is None:
+                metrics = trainer.evaluate()
+                trainer.objective = trainer.compute_objective(metrics)
+        return None
 
 
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
...
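The cross-rank handoff in the hunk above relies on torch.distributed.broadcast_object_list: rank 0 broadcasts the pickled TrainingArguments, and the other ranks rebuild them with pickle.loads. The standalone sketch below shows the same primitive in isolation; it simplifies to a single-element object list, a hypothetical payload, and the gloo backend, so it illustrates the pattern rather than reproducing the Trainer code.

# Sketch only: run with e.g. `torchrun --nproc_per_node=2 broadcast_sketch.py` (assumed launcher).
import pickle

import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # gloo keeps the sketch CPU-only
    rank = dist.get_rank()

    payload = {"learning_rate": 3e-5} if rank == 0 else None  # hypothetical "args"
    obj_list = [pickle.dumps(payload) if rank == 0 else None]

    # broadcast_object_list overwrites obj_list in place on every non-source rank.
    dist.broadcast_object_list(obj_list, src=0)

    if rank != 0:
        payload = pickle.loads(obj_list[0])
    print(f"rank {rank} received {payload}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()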
@@ -1210,7 +1210,7 @@ class Trainer:
                 value = type(old_attr)(value)
             setattr(self.args, key, value)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
-            logger.info("Trial:", trial.params)
+            logger.info(f"Trial: {trial.params}")
         if self.hp_search_backend == HPSearchBackend.SIGOPT:
             logger.info(f"SigOpt Assignments: {trial.assignments}")
         if self.hp_search_backend == HPSearchBackend.WANDB:
...