Unverified Commit b6295b26 authored by Alex Hall, committed by GitHub

Refactor hyperparameter search backends (#24384)

* Refactor hyperparameter search backends

* Simpler refactoring without abstract base class

* black

* review comments:
  - specify name in class
  - use methods instead of callable class attributes
  - name constant better

* review comments: safer bool checking, log multiple available backends

* test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black.

* copyright
parent a1c4b630
src/transformers/__init__.py
@@ -98,6 +98,7 @@ _import_structure = {
     "file_utils": [],
     "generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
     "hf_argparser": ["HfArgumentParser"],
+    "hyperparameter_search": [],
     "image_transforms": [],
     "integrations": [
         "is_clearml_available",
...

src/transformers/hyperparameter_search.py (new file)
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .integrations import (
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)
from .trainer_utils import (
HPSearchBackend,
default_hp_space_optuna,
default_hp_space_ray,
default_hp_space_sigopt,
default_hp_space_wandb,
)
from .utils import logging


logger = logging.get_logger(__name__)


class HyperParamSearchBackendBase:
    name: str
    pip_package: str = None

    def is_available(self):
        raise NotImplementedError

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        raise NotImplementedError

    def default_hp_space(self, trial):
        raise NotImplementedError

    def ensure_available(self):
        if not self.is_available():
            raise RuntimeError(
                f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
            )

    @classmethod
    def pip_install(cls):
        return f"`pip install {cls.pip_package or cls.name}`"


class OptunaBackend(HyperParamSearchBackendBase):
    name = "optuna"

    def is_available(self):
        return is_optuna_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        return default_hp_space_optuna(trial)


class RayTuneBackend(HyperParamSearchBackendBase):
    name = "ray"
    pip_package = "'ray[tune]'"

    def is_available(self):
        return is_ray_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        return default_hp_space_ray(trial)


class SigOptBackend(HyperParamSearchBackendBase):
    name = "sigopt"

    def is_available(self):
        return is_sigopt_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        return default_hp_space_sigopt(trial)


class WandbBackend(HyperParamSearchBackendBase):
    name = "wandb"

    def is_available(self):
        return is_wandb_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        return default_hp_space_wandb(trial)


ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
    HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
}


def default_hp_search_backend() -> str:
    # is_available() is an instance method, so instantiate each backend class before calling it.
    available_backends = [
        backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend().is_available()
    ]
    if len(available_backends) > 0:
        name = available_backends[0].name
        if len(available_backends) > 1:
            logger.info(
                f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
            )
        return name
    raise RuntimeError(
        "No hyperparameter search backend available.\n"
        + "\n".join(
            f" - To install {backend.name} run {backend.pip_install()}"
            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
        )
    )
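
To make the new module concrete, here is a short illustrative session (not part of the diff) exercising the registry and the default-backend helper:

```python
from transformers.hyperparameter_search import (
    ALL_HYPERPARAMETER_SEARCH_BACKENDS,
    default_hp_search_backend,
)
from transformers.trainer_utils import HPSearchBackend

# Each HPSearchBackend enum member maps to exactly one backend class.
backend = ALL_HYPERPARAMETER_SEARCH_BACKENDS[HPSearchBackend.OPTUNA]()
print(backend.name)           # "optuna"
print(backend.pip_install())  # "`pip install optuna`"

# Raises RuntimeError with an install hint if optuna is not importable.
backend.ensure_available()

# Name of the first installed backend; raises RuntimeError listing
# install commands for every backend if none is available.
print(default_hp_search_backend())
```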
src/transformers/integrations.py
@@ -177,15 +177,6 @@ def hp_params(trial):
     raise RuntimeError(f"Unknown type for trial {trial.__class__}")


-def default_hp_search_backend():
-    if is_optuna_available():
-        return "optuna"
-    elif is_ray_tune_available():
-        return "ray"
-    elif is_sigopt_available():
-        return "sigopt"
-
-
 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import optuna
...
src/transformers/trainer.py
@@ -36,18 +36,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
 # Integrations must be imported before ML frameworks:
 # isort: off
 from .integrations import (
-    default_hp_search_backend,
     get_reporting_integration_callbacks,
     hp_params,
     is_fairscale_available,
-    is_optuna_available,
-    is_ray_tune_available,
-    is_sigopt_available,
-    is_wandb_available,
-    run_hp_search_optuna,
-    run_hp_search_ray,
-    run_hp_search_sigopt,
-    run_hp_search_wandb,
 )
 # isort: on
@@ -66,6 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -114,7 +106,6 @@ from .trainer_utils import (
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
-    default_hp_space,
     denumpify_detensorize,
     enable_full_determinism,
     find_executable_batch_size,
@@ -2517,41 +2508,20 @@ class Trainer:
         """
         if backend is None:
             backend = default_hp_search_backend()
-            if backend is None:
-                raise RuntimeError(
-                    "At least one of optuna or ray should be installed. "
-                    "To install optuna run `pip install optuna`. "
-                    "To install ray run `pip install ray[tune]`. "
-                    "To install sigopt run `pip install sigopt`."
-                )
         backend = HPSearchBackend(backend)
-        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
-            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
-        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
-            raise RuntimeError(
-                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
-            )
-        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
-            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
-        if backend == HPSearchBackend.WANDB and not is_wandb_available():
-            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
+        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
+        backend_obj.ensure_available()
         self.hp_search_backend = backend
         if self.model_init is None:
             raise RuntimeError(
                 "To use hyperparameter search, you need to pass your model through a model_init function."
             )
-        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
         self.hp_name = hp_name
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
-        backend_dict = {
-            HPSearchBackend.OPTUNA: run_hp_search_optuna,
-            HPSearchBackend.RAY: run_hp_search_ray,
-            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
-            HPSearchBackend.WANDB: run_hp_search_wandb,
-        }
-        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
+        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
         self.hp_search_backend = None
         return best_run
...
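
For context, the public entry point is unchanged by this refactor. A minimal sketch of a caller (the checkpoint name, datasets, and search space below are placeholders for illustration, not part of this PR):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

train_dataset = eval_dataset = ...  # placeholder: tokenized datasets go here


def model_init():
    # hyperparameter_search requires a model_init so each trial starts fresh.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-cased")


def hp_space(trial):
    # Optuna-style search space; `trial` is an optuna.Trial.
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}


trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hp_search"),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# backend=None resolves via default_hp_search_backend(); an explicit name is
# normalized with HPSearchBackend(...) and checked via ensure_available().
best_run = trainer.hyperparameter_search(
    hp_space=hp_space, backend="optuna", n_trials=4, direction="minimize"
)
print(best_run.hyperparameters)
```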
src/transformers/trainer_utils.py
@@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum):
     WANDB = "wandb"


-default_hp_space = {
-    HPSearchBackend.OPTUNA: default_hp_space_optuna,
-    HPSearchBackend.RAY: default_hp_space_ray,
-    HPSearchBackend.SIGOPT: default_hp_space_sigopt,
-    HPSearchBackend.WANDB: default_hp_space_wandb,
-}
-
-
 def is_main_process(local_rank):
     """
     Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
...
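
The lookup this dict used to provide now goes through the backend classes from the new module; roughly (an illustrative equivalence, using names defined earlier in this diff):

```python
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
from transformers.trainer_utils import HPSearchBackend

# Old: hp_space_fn = default_hp_space[HPSearchBackend.OPTUNA]
# New: the same default space is reached via the backend object.
backend = ALL_HYPERPARAMETER_SEARCH_BACKENDS[HPSearchBackend.OPTUNA]()
hp_space_fn = backend.default_hp_space  # bound method; Trainer stores it as self.hp_space
```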
tests/trainer/test_trainer.py
@@ -42,6 +42,7 @@ from transformers import (
     is_torch_available,
     logging,
 )
+from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
 from transformers.testing_utils import (
     ENDPOINT_STAGING,
     TOKEN,
@@ -72,7 +73,7 @@ from transformers.testing_utils import (
     require_wandb,
     slow,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
         trainer.hyperparameter_search(
             direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
         )
+
+
+class HyperParameterSearchBackendsTest(unittest.TestCase):
+    def test_hyperparameter_search_backends(self):
+        self.assertEqual(
+            list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
+            list(HPSearchBackend),
+        )