Unverified Commit 6227078d authored by Wang, Yi, committed by GitHub

HPO: keep the original logic if there's only one process, pass the trial to trainer (#19096)



Open questions that still need a solution:
    * If we need to use the trial in model_init, how do we do it on non-main ranks? Sync the model with rank 0 in the application?
    * How do we use the Optuna pruning feature with DDP? If pruning is decided on rank 0, how do the other ranks learn about it? (One option is sketched below.)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 3b0cecb6
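Both open questions reduce to sharing rank-0 state with the other ranks, which is what the patch already does for the training arguments by broadcasting a pickled `trainer.args`. As a minimal sketch of the same idea applied to trial state, the hypothetical helper below (not part of this commit) broadcasts any picklable object, such as the sampled hyperparameters or a prune decision, from rank 0 using the same `torch.distributed` primitive the patch relies on:

```python
# Hypothetical helper, not part of this commit: share rank-0 trial state
# (sampled hyperparameters, a prune decision, ...) with the other DDP ranks.
# Assumes torch.distributed.init_process_group() has already been called.
from typing import Any

import torch.distributed as dist


def broadcast_from_rank0(obj: Any) -> Any:
    # broadcast_object_list requires a list of the same length on every
    # rank; entries on non-source ranks are overwritten in place.
    holder = [obj if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]
```

For pruning, rank 0 could call `trial.should_prune()`, broadcast the boolean with this helper, and have every rank raise together, so the ranks never diverge.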
@@ -168,12 +168,14 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
             if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                 checkpoint = os.path.join(checkpoint_dir, subdir)
         trainer.objective = None
-        trainer._hp_search_setup(trial)
         if trainer.args.world_size > 1:
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
+            trainer._hp_search_setup(trial)
             torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
             trainer.train(resume_from_checkpoint=checkpoint)
+        else:
+            trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
         # If there hasn't been any evaluation during the training loop.
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()
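For context on the new single-process `else` branch: passing `trial` into `trainer.train()` is what lets the `Trainer` report intermediate metrics back to Optuna so pruning can fire. The toy study below shows the plain Optuna report/prune cycle that the trial object enables; it is illustrative only and does not use the `Trainer` integration:

```python
# Toy single-process Optuna study showing the report/prune cycle.
# Values are illustrative; this is not the Trainer integration itself.
import optuna


def objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    loss = float("inf")
    for step in range(100):
        loss = lr * 100.0 / (step + 1)  # stand-in for an eval metric
        trial.report(loss, step)
        if trial.should_prune():  # the per-rank question raised above
            raise optuna.TrialPruned()
    return loss


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)
```

Under DDP only rank 0 runs this loop in the current design, which is exactly why the commit message asks how the other ranks would learn about a prune.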
@@ -362,12 +364,14 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
             for run in experiment.loop():
                 with run:
                     trainer.objective = None
-                    trainer._hp_search_setup(run.run)
                     if trainer.args.world_size > 1:
                         if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                             raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
+                        trainer._hp_search_setup(run.run)
                         torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
                         trainer.train(resume_from_checkpoint=None)
+                    else:
+                        trainer.train(resume_from_checkpoint=None, trial=run.run)
                     # If there hasn't been any evaluation during the training loop.
                     if getattr(trainer, "objective", None) is None:
                         metrics = trainer.evaluate()
@@ -397,12 +401,14 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
             while experiment.progress.observation_count < experiment.observation_budget:
                 suggestion = conn.experiments(experiment.id).suggestions().create()
                 trainer.objective = None
-                trainer._hp_search_setup(suggestion)
                 if trainer.args.world_size > 1:
                     if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                         raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
+                    trainer._hp_search_setup(suggestion)
                     torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
                     trainer.train(resume_from_checkpoint=None)
+                else:
+                    trainer.train(resume_from_checkpoint=None, trial=suggestion)
                 # If there hasn't been any evaluation during the training loop.
                 if getattr(trainer, "objective", None) is None:
                     metrics = trainer.evaluate()
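Both the Optuna and SigOpt paths above are reached through the same public entry point, `Trainer.hyperparameter_search`. A minimal sketch of driving the Optuna backend follows; the search space and trial count are illustrative, and `trainer` is assumed to be an existing `transformers.Trainer` built with `model_init` so each trial starts from a fresh model:

```python
# Illustrative call into the entry point that dispatches to
# run_hp_search_optuna. `trainer` is assumed to already exist and to have
# been constructed with model_init=..., so each trial rebuilds the model.
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [8, 16, 32]
        ),
    }


best_run = trainer.hyperparameter_search(
    hp_space=hp_space,
    backend="optuna",
    n_trials=10,
    direction="minimize",
)
print(best_run.hyperparameters)
```

With a single process this takes the new `else` branches above; with `world_size > 1` it takes the broadcast path.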