Unverified commit 23f9611c authored by krfricke, committed by GitHub

Add checkpointing to Ray Tune HPO (#6747)

* Introduce HPO checkpointing for PBT

* Moved checkpoint saving

* Fixed checkpoint subdir pass

* Fixed style

* Enable/disable checkpointing, check conditions for various tune schedulers incl. PBT

* Adjust number of GPUs to number of jobs

* Avoid model pickling in Ray

* Move hp search to integrations
parent 61b7ba93
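For context, here is a minimal usage sketch of the feature this commit adds. It is not part of the diff: `model_init`, `train_dataset`, `eval_dataset`, and all hyperparameter values are illustrative placeholders, and the extra keyword arguments (`scheduler`, `keep_checkpoints_num`) are simply forwarded to `ray.tune.run` by the code below.

```python
from ray.tune.schedulers import PopulationBasedTraining
from transformers import Trainer, TrainingArguments

# Intermediate evaluation is required so that trials report an objective to Ray Tune.
training_args = TrainingArguments(
    output_dir="./hpo",
    do_eval=True,
    evaluate_during_training=True,
    save_steps=500,              # Tune checkpoints are written on the same schedule
)

trainer = Trainer(
    model_init=model_init,       # user-provided factory returning a fresh model per trial
    args=training_args,
    train_dataset=train_dataset, # user-provided datasets
    eval_dataset=eval_dataset,
)

# Illustrative PBT setup; "objective" is the metric reported via `tune.report`.
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="objective",
    mode="max",
    perturbation_interval=1,
    hyperparam_mutations={"learning_rate": [1e-5, 3e-5, 5e-5]},
)

best_run = trainer.hyperparameter_search(
    backend="ray",
    direction="maximize",
    n_trials=4,
    scheduler=pbt,
    keep_checkpoints_num=1,      # enables the new Tune checkpointing path
)
```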
# Integrations with other Python libraries
import os

import numpy as np

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, HPSearchBackend
from transformers.utils import logging

logger = logging.get_logger(__name__)

try:
    import comet_ml  # noqa: F401
@@ -75,3 +81,95 @@ def default_hp_search_backend():
        return "optuna"
    elif is_ray_available():
        return "ray"
def run_hp_search(trainer, n_trials, direction, kwargs):
    def _objective(trial, checkpoint_dir=None):
        model_path = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    model_path = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(model_path=model_path, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            if trainer.hp_search_backend == HPSearchBackend.RAY:
                trainer._tune_save_checkpoint()
                ray.tune.report(objective=trainer.objective)
        return trainer.objective

    if trainer.hp_search_backend == HPSearchBackend.OPTUNA:
        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        study = optuna.create_study(direction=direction, **kwargs)
        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
        best_trial = study.best_trial
        best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
    elif trainer.hp_search_backend == HPSearchBackend.RAY:
        # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
        # while doing the ray hp search.
        _tb_writer = trainer.tb_writer
        trainer.tb_writer = None
        trainer.model = None
        # Setup default `resources_per_trial` and `reporter`.
        if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
            # `args.n_gpu` is considered the total number of GPUs that will be split
            # among the `n_jobs`.
            n_jobs = int(kwargs.pop("n_jobs", 1))
            num_gpus_per_trial = trainer.args.n_gpu
            if num_gpus_per_trial / n_jobs >= 1:
                num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
            kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}

        if "reporter" not in kwargs:
            from ray.tune import CLIReporter

            kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
        if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
            # `keep_checkpoints_num=0` would disable checkpointing
            trainer.use_tune_checkpoints = True
            if kwargs["keep_checkpoints_num"] > 1:
                logger.warning(
                    "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
                    "consider setting `keep_checkpoints_num=1`.".format(kwargs["keep_checkpoints_num"])
                )
        if "scheduler" in kwargs:
            from ray.tune.schedulers import (
                ASHAScheduler,
                HyperBandForBOHB,
                MedianStoppingRule,
                PopulationBasedTraining,
            )

            # Check if checkpointing is enabled for PopulationBasedTraining
            if isinstance(kwargs["scheduler"], PopulationBasedTraining):
                if not trainer.use_tune_checkpoints:
                    logger.warning(
                        "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                        "This means your trials will train from scratch every time they are exploiting "
                        "new configurations. Consider enabling checkpointing by passing "
                        "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                    )

            # Check for `do_eval` and `evaluate_during_training` for schedulers that require intermediate reporting.
            if isinstance(
                kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
            ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
                raise RuntimeError(
                    "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                    "This means your trials will not report intermediate results to Ray Tune, and "
                    "can thus not be stopped early or used to exploit other trials' parameters. "
                    "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                    "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
                    "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
                )

        analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
        best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
        best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
        trainer.tb_writer = _tb_writer
    return best_run
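The scheduler checks above are easiest to see with a concrete call. A hedged sketch, assuming an existing `trainer` built with `model_init` and `TrainingArguments(do_eval=True, evaluate_during_training=True, ...)`; the ASHA parameters are illustrative:

```python
from ray.tune.schedulers import ASHAScheduler

# "objective" is the metric name reported by `_objective` via `ray.tune.report`.
asha = ASHAScheduler(
    time_attr="training_iteration",
    metric="objective",
    mode="max",
    grace_period=1,
    reduction_factor=2,
)

# `scheduler` is forwarded to `ray.tune.run` through **kwargs; without
# `do_eval` and `evaluate_during_training`, the RuntimeError above is raised.
best_run = trainer.hyperparameter_search(
    backend="ray",
    direction="maximize",
    n_trials=8,
    scheduler=asha,
)
```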
@@ -27,6 +27,7 @@ from .integrations import (
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search,
)
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -295,6 +296,7 @@ class Trainer:
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False

    def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
@@ -544,8 +546,21 @@ class Trainer:
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
@@ -869,40 +884,7 @@ class Trainer:
        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
        def _objective(trial):
            self.objective = None
            self.train(trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(self, "objective", None) is None:
                metrics = self.evaluate()
                self.objective = self.compute_objective(metrics)
                if self.hp_search_backend == HPSearchBackend.RAY:
                    tune.report(objective=self.objective)
            return self.objective

        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            timeout = kwargs.pop("timeout", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            study = optuna.create_study(direction=direction, **kwargs)
            study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
            best_trial = study.best_trial
            best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            # The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp
            # search.
            _tb_writer = self.tb_writer
            self.tb_writer = None
            # Setup default `resources_per_trial` and `reporter`.
            if "resources_per_trial" not in kwargs and self.args.n_gpu > 0:
                kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu}
            if "reporter" not in kwargs:
                from ray.tune import CLIReporter

                kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
            analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs)
            best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
            best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
            self.tb_writer = _tb_writer
        best_run = run_hp_search(self, n_trials, direction, kwargs)
        self.hp_search_backend = None
        return best_run
@@ -4,7 +4,6 @@ from typing import Any, Dict, NamedTuple, Optional
import numpy as np

from .file_utils import is_tf_available, is_torch_available
from .integrations import is_ray_available
from .tokenization_utils_base import ExplicitEnum
@@ -93,6 +92,9 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
def default_hp_space_optuna(trial) -> Dict[str, float]:
    from .integrations import is_optuna_available

    assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
@@ -102,12 +104,14 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
def default_hp_space_ray(trial) -> Dict[str, float]:
    from .integrations import is_ray_available

    assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
    from ray import tune

    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "num_train_epochs": tune.choice(range(1, 6)),
        "num_train_epochs": tune.choice(list(range(1, 6))),
        "seed": tune.uniform(1, 40),
        "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
    }
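If the default Ray search space is not a good fit, a custom `hp_space` callable can be passed to `hyperparameter_search`. A hedged sketch; the parameter names mirror the defaults above, the ranges are illustrative, and `trainer` is assumed to exist:

```python
from ray import tune


def my_hp_space(trial):
    # For the Ray backend the `trial` argument is ignored: the space is passed
    # to `tune.run` as `config=trainer.hp_space(None)`.
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "num_train_epochs": tune.choice([2, 3, 4]),
        "seed": tune.choice(list(range(1, 41))),
        "per_device_train_batch_size": tune.choice([16, 32]),
    }


best_run = trainer.hyperparameter_search(hp_space=my_hp_space, backend="ray", n_trials=10)
```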