Unverified Commit 49629e7b authored by Wang, Yi, committed by GitHub

fix HPO DDP GPU problem (#19168)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 8d59385f
@@ -23,6 +23,7 @@ import pickle
 import shutil
 import sys
 import tempfile
+from dataclasses import asdict
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, Optional
@@ -195,9 +196,10 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
             torch.distributed.broadcast_object_list(args_main_rank, src=0)
-            local_rank = trainer.args.local_rank  # backup the local_rank info
-            trainer.args = pickle.loads(bytes(args_main_rank))
-            trainer.args.local_rank = local_rank
+            args = pickle.loads(bytes(args_main_rank))
+            for key, value in asdict(args).items():
+                if key != "local_rank":
+                    setattr(trainer.args, key, value)
             trainer.train(resume_from_checkpoint=None)
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
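
For context, a minimal standalone sketch of the pattern this hunk implements: rank 0's trial `TrainingArguments` are serialized, broadcast to every worker with `torch.distributed.broadcast_object_list`, and copied field by field onto the worker's existing args, skipping `local_rank` so each process keeps its own GPU binding. The helper name `sync_args_from_rank0` and the single-element-list broadcast are illustrative assumptions; the patch itself broadcasts the byte list `args_main_rank`, which is prepared outside this hunk.

```python
# Illustrative sketch only -- not the exact Transformers code. It mirrors the
# broadcast-then-copy pattern of the hunk above and assumes an already
# initialized torch.distributed process group.
import pickle
from dataclasses import asdict

import torch.distributed as dist


def sync_args_from_rank0(trainer):
    """Hypothetical helper: copy rank 0's trial arguments onto this worker,
    keeping the worker's own local_rank so DDP device placement is unchanged."""
    # broadcast_object_list mutates the list in place and requires the same
    # list length on every rank, so a single-element list is the simplest form.
    payload = [pickle.dumps(trainer.args) if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    args = pickle.loads(payload[0])
    # Overwrite every field except local_rank, exactly as the patch does.
    for key, value in asdict(args).items():
        if key != "local_rank":
            setattr(trainer.args, key, value)
```
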
@@ -429,9 +431,10 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
             torch.distributed.broadcast_object_list(args_main_rank, src=0)
-            local_rank = trainer.args.local_rank  # backup the local_rank info
-            trainer.args = pickle.loads(bytes(args_main_rank))
-            trainer.args.local_rank = local_rank
+            args = pickle.loads(bytes(args_main_rank))
+            for key, value in asdict(args).items():
+                if key != "local_rank":
+                    setattr(trainer.args, key, value)
             trainer.train(resume_from_checkpoint=None)
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
@@ -470,7 +473,6 @@ def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     sweep_config["name"] = name

     def _objective():
         run = wandb.run if wandb.run else wandb.init()
         trainer.state.trial_name = run.name
         run.config.update({"assignments": {}, "metric": metric})
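
And a hedged sketch of how a non-zero rank would consume the synced arguments per trial, mirroring the optuna/sigopt hunks above. `run_trial_on_worker` is an assumed name, and the evaluate/compute_objective fallback follows the `objective is None` guard shown in the diff but is otherwise an assumption.

```python
# Illustrative per-trial worker loop (assumed structure, not the library code).
def run_trial_on_worker(trainer):
    sync_args_from_rank0(trainer)               # helper sketched above
    trainer.train(resume_from_checkpoint=None)  # all ranks join the same DDP run
    # If no evaluation happened during training, compute the objective now,
    # mirroring the guard in the hunks above (the evaluate call here is an
    # assumption; the diff is truncated before that point).
    if getattr(trainer, "objective", None) is None:
        metrics = trainer.evaluate()
        trainer.objective = trainer.compute_objective(metrics)
```
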