Unverified Commit a59bcefb authored by Sylvain Gugger, committed by GitHub

Split hp search methods (#6857)

* Split the run_hp_search by backend

* Unused import
parent 23f9611c
@@ -3,7 +3,7 @@ import os
 import numpy as np

-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, HPSearchBackend
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
 from transformers.utils import logging
@@ -83,7 +83,7 @@ def default_hp_search_backend():
     return "ray"


-def run_hp_search(trainer, n_trials, direction, kwargs):
+def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -96,19 +96,33 @@ def run_hp_search(trainer, n_trials, direction, kwargs):
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()
             trainer.objective = trainer.compute_objective(metrics)
-            if trainer.hp_search_backend == HPSearchBackend.RAY:
-                trainer._tune_save_checkpoint()
-                ray.tune.report(objective=trainer.objective)
         return trainer.objective

-    if trainer.hp_search_backend == HPSearchBackend.OPTUNA:
     timeout = kwargs.pop("timeout", None)
     n_jobs = kwargs.pop("n_jobs", 1)
     study = optuna.create_study(direction=direction, **kwargs)
     study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
     best_trial = study.best_trial
-        best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
-    elif trainer.hp_search_backend == HPSearchBackend.RAY:
+    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+
+
+def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    def _objective(trial, checkpoint_dir=None):
+        model_path = None
+        if checkpoint_dir:
+            for subdir in os.listdir(checkpoint_dir):
+                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
+                    model_path = os.path.join(checkpoint_dir, subdir)
+        trainer.objective = None
+        trainer.train(model_path=model_path, trial=trial)
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+            trainer._tune_save_checkpoint()
+            ray.tune.report(objective=trainer.objective)
+        return trainer.objective

     # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
     # while doing the ray hp search.
     _tb_writer = trainer.tb_writer
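The Optuna half of the split above follows the standard Optuna loop: create a study, optimize an objective callable, then read the result off `study.best_trial`, exactly as `run_hp_search_optuna` does with `_objective`. A minimal standalone sketch of that pattern is below; the objective function and the hyperparameter names are illustrative placeholders, not part of this commit.

import optuna

# Hypothetical objective: Optuna passes a `trial` object that samples hyperparameters.
def objective(trial):
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
    num_train_epochs = trial.suggest_int("num_train_epochs", 1, 5)
    # A real objective would train and evaluate a model here; we return a fake score.
    return learning_rate * num_train_epochs

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10, timeout=None, n_jobs=1)

best_trial = study.best_trial
print(best_trial.number, best_trial.value, best_trial.params)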
@@ -137,12 +151,7 @@ def run_hp_search(trainer, n_trials, direction, kwargs):
                 "consider setting `keep_checkpoints_num=1`."
             )
     if "scheduler" in kwargs:
-            from ray.tune.schedulers import (
-                ASHAScheduler,
-                HyperBandForBOHB,
-                MedianStoppingRule,
-                PopulationBasedTraining,
-            )
+        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

         # Check if checkpointing is enabled for PopulationBasedTraining
         if isinstance(kwargs["scheduler"], PopulationBasedTraining):
@@ -171,5 +180,4 @@ def run_hp_search(trainer, n_trials, direction, kwargs):
     best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
     best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
     trainer.tb_writer = _tb_writer
     return best_run
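The Ray path, which the remaining hunks of this file cover, follows Ray Tune's usual flow: a trainable reports a metric with `ray.tune.report`, `ray.tune.run` launches the trials, and the winner is read back via `analysis.get_best_trial(metric="objective", mode=...)`. The sketch below shows that pattern in isolation; the config and trainable are placeholders, and it uses the 2020-era `tune.run` API rather than the newer `Tuner` interface.

from ray import tune

# Hypothetical trainable: Ray calls it with a sampled config and it reports a metric.
def trainable(config):
    score = config["learning_rate"] * config["num_train_epochs"]  # stand-in for an eval metric
    tune.report(objective=score)

analysis = tune.run(
    trainable,
    config={
        "learning_rate": tune.loguniform(1e-6, 1e-3),
        "num_train_epochs": tune.choice([1, 2, 3]),
    },
    num_samples=10,  # plays the role of n_trials
)

best_trial = analysis.get_best_trial(metric="objective", mode="min")  # direction[:3] in the diff
print(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)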
@@ -27,7 +27,8 @@ from .integrations import (
     is_ray_available,
     is_tensorboard_available,
     is_wandb_available,
-    run_hp_search,
+    run_hp_search_optuna,
+    run_hp_search_ray,
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -884,7 +885,8 @@ class Trainer:
         self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

-        best_run = run_hp_search(self, n_trials, direction, kwargs)
+        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
+        best_run = run_hp_search(self, n_trials, direction, **kwargs)

         self.hp_search_backend = None
         return best_run
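After this change, `Trainer.hyperparameter_search` picks the backend-specific function itself and forwards `**kwargs` to it, so backend options such as Optuna's `n_jobs`/`timeout` or Ray Tune's scheduler settings pass straight through. A hedged usage sketch follows; the trainer, the search space values, and the constant names are placeholders, and the keyword names assume the `hyperparameter_search` signature as of this commit.

# Assumes `trainer` was constructed with model_init=..., so each trial gets a fresh model.
def my_hp_space(trial):
    # Optuna-style search space; for backend="ray" this would instead return a Ray Tune config dict.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
    }

best_run = trainer.hyperparameter_search(
    hp_space=my_hp_space,
    n_trials=20,
    direction="minimize",
    backend="optuna",  # or "ray"
    n_jobs=1,          # extra kwargs are now forwarded to run_hp_search_optuna / run_hp_search_ray
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)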