Unverified commit d63ab615 authored by Kai Fricke, committed by GitHub

Use object store to pass trainer object to Ray Tune (#9749)

parent 6312fed4
@@ -149,20 +149,20 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray

-    def _objective(trial, checkpoint_dir=None):
+    def _objective(trial, local_trainer, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                     model_path = os.path.join(checkpoint_dir, subdir)
-        trainer.objective = None
-        trainer.train(model_path=model_path, trial=trial)
+        local_trainer.objective = None
+        local_trainer.train(model_path=model_path, trial=trial)
         # If there hasn't been any evaluation during the training loop.
-        if getattr(trainer, "objective", None) is None:
-            metrics = trainer.evaluate()
-            trainer.objective = trainer.compute_objective(metrics)
-            trainer._tune_save_checkpoint()
-            ray.tune.report(objective=trainer.objective, **metrics, done=True)
+        if getattr(local_trainer, "objective", None) is None:
+            metrics = local_trainer.evaluate()
+            local_trainer.objective = local_trainer.compute_objective(metrics)
+            local_trainer._tune_save_checkpoint()
+            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

     # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
     # while doing the ray hp search.
@@ -217,7 +217,12 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
             "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
         )

-    analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
+    analysis = ray.tune.run(
+        ray.tune.with_parameters(_objective, local_trainer=trainer),
+        config=trainer.hp_space(None),
+        num_samples=n_trials,
+        **kwargs,
+    )
     best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
     best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
     if _tb_writer is not None:
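The change swaps closure capture for Ray's object store: ray.tune.with_parameters stores the trainer in the object store once and injects it into every trial as the local_trainer argument, instead of serializing it into each trial's function. Below is a minimal self-contained sketch of the same pattern, assuming the Ray 1.x tune API that this diff uses; big_payload, payload, and score are illustrative names, not identifiers from this commit.

import numpy as np
from ray import tune

# Stand-in for a large, expensive-to-serialize object such as the Trainer.
big_payload = np.zeros((10_000, 100))

def _objective(config, payload):
    # `payload` is fetched from the Ray object store; only `config`
    # (the sampled hyperparameters) differs between trials.
    score = float(payload.sum()) + config["x"]
    tune.report(objective=score, done=True)

analysis = tune.run(
    # with_parameters() puts `big_payload` into the object store once
    # and passes it to each trial as the `payload` argument.
    tune.with_parameters(_objective, payload=big_payload),
    config={"x": tune.uniform(0.0, 1.0)},
    num_samples=4,
)
print(analysis.get_best_trial(metric="objective", mode="max").config)

Without with_parameters, the payload would be pickled into the trainable itself and shipped separately for every trial, which is exactly the overhead this commit avoids for the trainer object.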