"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "47df0f2234923e966d81b86d4d8c07d8727b66d4"
Unverified commit 5fa66df3 authored by Justin Yu and committed by GitHub

[integration] Update Ray Tune integration for Ray 2.7 (#26499)



* fix tune integration for ray 2.7+
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* add version check for ray tune backend availability
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* missing import
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* pin min version instead
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* address comments
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* some fixes
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* fix unnecessary final checkpoint
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* fix lint
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* dep table fix
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

* fix lint
Signed-off-by: Justin Yu <justinvyu@anyscale.com>

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
parent ffd426ee
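The change set below migrates the Ray Tune integration from the session API that Ray removed in 2.7 (`tune.report`, `tune.checkpoint_dir`, `tune.get_trial_id`) to the unified `ray.train` API. A minimal sketch of the new reporting pattern, where the helper name and the `save_fn` callback are placeholders rather than anything defined in this PR:

```python
import tempfile

import ray.train


def report_with_checkpoint(metrics: dict, save_fn) -> None:
    # Hypothetical helper (not part of this PR): persist trial state into a
    # temporary directory, wrap it as a Ray Train Checkpoint, and report it
    # together with the metrics. This replaces the pre-2.7
    # `tune.report(...)` + `tune.checkpoint_dir(...)` flow and is only valid
    # inside a running Tune trial.
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        save_fn(temp_checkpoint_dir)  # e.g. write model weights / trainer state
        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        ray.train.report(metrics, checkpoint=checkpoint)
```

The same pattern appears twice in the diff: once in the trial objective of the integration module and once in `Trainer._report_to_hp_search`.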
@@ -149,7 +149,7 @@ _deps = [
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ray[tune]",
+    "ray[tune]>=2.7.0",
     "regex!=2019.12.17",
     "requests",
     "rhoknp>=1.1.0,<1.3.1",
...
@@ -55,7 +55,7 @@ deps = {
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ray[tune]": "ray[tune]",
+    "ray[tune]": "ray[tune]>=2.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "rhoknp": "rhoknp>=1.1.0,<1.3.1",
...
@@ -15,7 +15,7 @@
 from .integrations import (
     is_optuna_available,
-    is_ray_available,
+    is_ray_tune_available,
     is_sigopt_available,
     is_wandb_available,
     run_hp_search_optuna,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):
     @staticmethod
     def is_available():
-        return is_ray_available()
+        return is_ray_tune_available()

     def run(self, trainer, n_trials: int, direction: str, **kwargs):
         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
...
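Backend availability is now gated on `is_ray_tune_available()` instead of `is_ray_available()`, so environments with a bare `ray` install (without the Tune extra) no longer advertise a backend they cannot run. A rough, hypothetical sketch of such a check; the actual helper lives in `transformers.integrations` and may differ:

```python
import importlib.metadata
import importlib.util

from packaging import version


def ray_tune_usable(min_version: str = "2.7.0") -> bool:
    # Hypothetical availability check: ray must be importable together with
    # its Tune subpackage, and the installed version must satisfy the pin.
    if importlib.util.find_spec("ray") is None:
        return False
    try:
        installed = version.parse(importlib.metadata.version("ray"))
    except importlib.metadata.PackageNotFoundError:
        return False
    return installed >= version.parse(min_version) and importlib.util.find_spec("ray.tune") is not None
```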
@@ -236,8 +236,9 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:

 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train

-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback
@@ -246,19 +247,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         except ModuleNotFoundError:
             pass

-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)

         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                ray.train.report(metrics, checkpoint=checkpoint)

     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         from ray.tune import CLIReporter

         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )
     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
             kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
...
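With the `keep_checkpoints_num` plumbing and the PopulationBasedTraining warning removed, checkpoints are always reported through `ray.train.report`, so checkpoint-based schedulers can simply be passed through as extra kwargs. A hedged sketch; the mutation space, interval, and trial count are placeholders:

```python
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

# Placeholder mutation space and interval; adapt to the hyperparameters you search over.
pbt_scheduler = PopulationBasedTraining(
    metric="objective",
    mode="max",
    perturbation_interval=2,
    hyperparam_mutations={"learning_rate": tune.loguniform(1e-5, 1e-3)},
)

# `trainer` is assumed to be a transformers.Trainer built with `model_init=...`;
# extra kwargs such as `scheduler` are forwarded to Ray Tune by run_hp_search_ray.
# best_run = trainer.hyperparameter_search(
#     backend="ray", n_trials=4, direction="maximize", scheduler=pbt_scheduler
# )
```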
@@ -28,6 +28,7 @@ import random
 import re
 import shutil
 import sys
+import tempfile
 import time
 import warnings
 from collections.abc import Mapping
@@ -595,7 +596,6 @@ class Trainer:
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
         self.hp_search_backend = None
-        self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ class Trainer:
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
             return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
@@ -1211,18 +1212,17 @@ class Trainer:
                 self.callback_handler.on_train_end(self.args, self.state, self.control)
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)

-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
         self.save_model(output_dir, _internal_call=True)
         if self.args.should_save:
@@ -2004,9 +2004,9 @@ class Trainer:
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             run_id = trial.number
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            run_id = tune.get_trial_id()
+            import ray.train
+
+            run_id = ray.train.get_context().get_trial_id()
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             run_id = trial.id
         elif self.hp_search_backend == HPSearchBackend.WANDB:
...
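None of this changes the public entry point: `Trainer.hyperparameter_search(backend="ray", ...)` still drives the search, it just requires `ray[tune]>=2.7.0` underneath. A minimal usage sketch, assuming a `Trainer` built with `model_init` and an illustrative search space:

```python
from ray import tune


def ray_hp_space(trial):
    # Illustrative Ray Tune search space; any Tune distributions work here.
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }


def run_search(trainer):
    # `trainer` is a transformers.Trainer created with `model_init=...` so that
    # every trial can re-instantiate the model. Needs ray[tune]>=2.7.0 installed.
    best_run = trainer.hyperparameter_search(
        hp_space=ray_hp_space,
        backend="ray",
        n_trials=4,
        direction="maximize",
    )
    return best_run.hyperparameters
```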