Unverified Commit 5fa66df3 authored by Justin Yu's avatar Justin Yu Committed by GitHub
Browse files

[integration] Update Ray Tune integration for Ray 2.7 (#26499)



* fix tune integration for ray 2.7+
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* add version check for ray tune backend availability
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* missing import
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* pin min version instead
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* address comments
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* some fixes
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* fix unnecessary final checkpoint
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* fix lint
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* dep table fix
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

* fix lint
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>

---------
Signed-off-by: default avatarJustin Yu <justinvyu@anyscale.com>
parent ffd426ee
......@@ -149,7 +149,7 @@ _deps = [
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"ray[tune]",
"ray[tune]>=2.7.0",
"regex!=2019.12.17",
"requests",
"rhoknp>=1.1.0,<1.3.1",
......
......@@ -55,7 +55,7 @@ deps = {
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ray[tune]": "ray[tune]",
"ray[tune]": "ray[tune]>=2.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
......
......@@ -15,7 +15,7 @@
from .integrations import (
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
......@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):
@staticmethod
def is_available():
return is_ray_available()
return is_ray_tune_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
......
......@@ -236,8 +236,9 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import ray
import ray.train
def _objective(trial, local_trainer, checkpoint_dir=None):
def _objective(trial: dict, local_trainer):
try:
from transformers.utils.notebook import NotebookProgressCallback
......@@ -246,19 +247,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
except ModuleNotFoundError:
pass
checkpoint = None
if checkpoint_dir:
for subdir in os.listdir(checkpoint_dir):
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
checkpoint = os.path.join(checkpoint_dir, subdir)
local_trainer.objective = None
local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
checkpoint = ray.train.get_checkpoint()
if checkpoint:
# Upon trial resume, the local_trainer's objective gets reset to None.
# If `local_trainer.train` is a noop (training has already reached
# the target number of epochs/steps), then this would
# trigger an unnecessary extra checkpoint at the end of training.
# -> Set the objective to a dummy value upon resume as a workaround.
local_trainer.objective = "objective"
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
else:
local_trainer.train(trial=trial)
# If there hasn't been any evaluation during the training loop.
if getattr(local_trainer, "objective", None) is None:
metrics = local_trainer.evaluate()
local_trainer.objective = local_trainer.compute_objective(metrics)
local_trainer._tune_save_checkpoint()
ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
metrics.update({"objective": local_trainer.objective, "done": True})
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
ray.train.report(metrics, checkpoint=checkpoint)
if not trainer._memory_tracker.skip_memory_metrics:
from ..trainer_utils import TrainerMemoryTracker
......@@ -296,28 +312,10 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
from ray.tune import CLIReporter
kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
# `keep_checkpoints_num=0` would disabled checkpointing
trainer.use_tune_checkpoints = True
if kwargs["keep_checkpoints_num"] > 1:
logger.warning(
f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
"Checkpoints are usually huge, "
"consider setting `keep_checkpoints_num=1`."
)
if "scheduler" in kwargs:
from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
# Check if checkpointing is enabled for PopulationBasedTraining
if isinstance(kwargs["scheduler"], PopulationBasedTraining):
if not trainer.use_tune_checkpoints:
logger.warning(
"You are using PopulationBasedTraining but you haven't enabled checkpointing. "
"This means your trials will train from scratch everytime they are exploiting "
"new configurations. Consider enabling checkpointing by passing "
"`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
)
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
if isinstance(
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
......
......@@ -28,6 +28,7 @@ import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
......@@ -595,7 +596,6 @@ class Trainer:
# returned to 0 every time flos need to be logged
self.current_flos = 0
self.hp_search_backend = None
self.use_tune_checkpoints = False
default_label_names = find_labels(self.model.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.can_return_loss = can_return_loss(self.model.__class__)
......@@ -1201,7 +1201,8 @@ class Trainer:
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
if self.hp_search_backend is None or trial is None:
return
self.objective = self.compute_objective(metrics.copy())
metrics = metrics.copy()
self.objective = self.compute_objective(metrics)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
......@@ -1211,24 +1212,23 @@ class Trainer:
self.callback_handler.on_train_end(self.args, self.state, self.control)
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
if self.control.should_save:
self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self):
from ray import tune
if not self.use_tune_checkpoints:
return
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
import ray.train
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
if self.control.should_save:
self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
metrics["objective"] = self.objective
ray.train.report(metrics, checkpoint=checkpoint)
def _tune_save_checkpoint(self, checkpoint_dir: str):
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
def call_model_init(self, trial=None):
model_init_argcount = number_of_arguments(self.model_init)
......@@ -2004,9 +2004,9 @@ class Trainer:
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
import ray.train
run_id = tune.get_trial_id()
run_id = ray.train.get_context().get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
elif self.hp_search_backend == HPSearchBackend.WANDB:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment