Unverified Commit 8f609ab9 authored by Wang, Yi, committed by GitHub

enable optuna multi-objectives feature (#25969)



* enable optuna multi-objectives feature
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update hpo doc

* update docstring
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* extend direction to List[str] type
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Update src/transformers/integrations/integration_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 92f2fbad
@@ -54,6 +54,18 @@ For optuna, see optuna [object_parameter](https://optuna.readthedocs.io/en/stabl
 ...     }
 ```
+Optuna provides multi-objective HPO. You can pass `direction` as a list in `hyperparameter_search` and define your own `compute_objective` to return multiple objective values. The Pareto front (`List[BestRun]`) is then returned by `hyperparameter_search`; refer to the test case `TrainerHyperParameterMultiObjectOptunaIntegrationTest` in [test_trainer](https://github.com/huggingface/transformers/blob/main/tests/trainer/test_trainer.py). For example:
+```py
+>>> best_trials = trainer.hyperparameter_search(
+...     direction=["minimize", "maximize"],
+...     backend="optuna",
+...     hp_space=optuna_hp_space,
+...     n_trials=20,
+...     compute_objective=compute_objective,
+... )
+```
 For raytune, see raytune [object_parameter](https://docs.ray.io/en/latest/tune/api/search_space.html), it's like following:
 ```py
......
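The doc example above references `optuna_hp_space` and `compute_objective` without defining them. A minimal sketch of what they could look like; the search-space parameters and the pair of metrics are illustrative assumptions, not part of this patch:

```py
def optuna_hp_space(trial):
    # Hypothetical search space; any optuna `suggest_*` calls work here.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
    }


def compute_objective(metrics):
    # Return one value per entry in `direction`: eval_loss is minimized,
    # eval_accuracy is maximized (assumes `compute_metrics` reports both).
    return metrics["eval_loss"], metrics["eval_accuracy"]
```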
@@ -205,10 +205,16 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
         timeout = kwargs.pop("timeout", None)
         n_jobs = kwargs.pop("n_jobs", 1)
-        study = optuna.create_study(direction=direction, **kwargs)
+        directions = direction if isinstance(direction, list) else None
+        direction = None if directions is not None else direction
+        study = optuna.create_study(direction=direction, directions=directions, **kwargs)
         study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
-        best_trial = study.best_trial
-        return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+        if not study._is_multi_objective():
+            best_trial = study.best_trial
+            return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+        else:
+            best_trials = study.best_trials
+            return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
     else:
         for i in range(n_trials):
             trainer.objective = None
......
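For context, a standalone sketch of the optuna API this hunk dispatches to (assumes `optuna` is installed; the toy objective is illustrative). A study created with `directions=` has no single `best_trial`; the Pareto-optimal set is read from `study.best_trials`, which the code above maps into a `List[BestRun]`:

```py
import optuna


def objective(trial):
    x = trial.suggest_float("x", 0.0, 5.0)
    # Two competing objectives: x**2 is minimized while x is maximized.
    return x**2, x


study = optuna.create_study(directions=["minimize", "maximize"])
study.optimize(objective, n_trials=20)

# The Pareto front: no trial here is strictly better than another on both objectives.
for trial in study.best_trials:
    print(trial.number, trial.values, trial.params)
```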
@@ -1233,10 +1233,11 @@ class Trainer:
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna

-            trial.report(self.objective, step)
-            if trial.should_prune():
-                self.callback_handler.on_train_end(self.args, self.state, self.control)
-                raise optuna.TrialPruned()
+            if not trial.study._is_multi_objective():
+                trial.report(self.objective, step)
+                if trial.should_prune():
+                    self.callback_handler.on_train_end(self.args, self.state, self.control)
+                    raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
             from ray import tune
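The guard above exists because optuna defines intermediate-value reporting, and therefore pruning, only for single-objective studies; calling `trial.report` on a multi-objective trial raises. A sketch of the behavior being avoided (assumes optuna is installed; the objective is a toy):

```py
import optuna


def objective(trial):
    x = trial.suggest_float("x", 0.0, 1.0)
    try:
        trial.report(x, step=0)  # not supported for multi-objective studies
    except Exception as err:
        print(f"report failed: {type(err).__name__}")
    return x, 1 - x


study = optuna.create_study(directions=["minimize", "minimize"])
study.optimize(objective, n_trials=1)
```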
@@ -2563,11 +2564,11 @@ class Trainer:
         hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
         compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
         n_trials: int = 20,
-        direction: str = "minimize",
+        direction: Union[str, List[str]] = "minimize",
         backend: Optional[Union["str", HPSearchBackend]] = None,
         hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
         **kwargs,
-    ) -> BestRun:
+    ) -> Union[BestRun, List[BestRun]]:
         """
         Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
         by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
@@ -2592,9 +2593,12 @@ class Trainer:
                 method. Will default to [`~trainer_utils.default_compute_objective`].
             n_trials (`int`, *optional*, defaults to 100):
                 The number of trial runs to test.
-            direction (`str`, *optional*, defaults to `"minimize"`):
-                Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
-                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
+            direction (`str` or `List[str]`, *optional*, defaults to `"minimize"`):
+                For single-objective optimization, `direction` is a `str` and can be `"minimize"` or `"maximize"`:
+                pick `"minimize"` when optimizing the validation loss and `"maximize"` when optimizing one or
+                several metrics. For multi-objective optimization, `direction` is a `List[str]` containing
+                `"minimize"` and/or `"maximize"`, with one entry per value returned by `compute_objective`,
+                following the same rule for each objective.
             backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                 The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                 on which one is installed. If all are installed, will default to optuna.
@@ -2610,8 +2614,9 @@ class Trainer:
             - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

         Returns:
-            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
-            `run_summary` attribute for Ray backend.
+            [`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or,
+            for multi-objective optimization, the best runs. Experiment summary can be found in the `run_summary`
+            attribute for the Ray backend.
         """
         if backend is None:
             backend = default_hp_search_backend()
......
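Since `hyperparameter_search` now returns either a single `BestRun` or a list, callers that support both modes can normalize the result. A small sketch, reusing the names from the docs example above:

```py
result = trainer.hyperparameter_search(
    direction=["minimize", "maximize"],
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=20,
    compute_objective=compute_objective,
)

# Multi-objective searches return the Pareto front as a list; wrap a single
# BestRun so the same loop handles both cases.
best_runs = result if isinstance(result, list) else [result]
for run in best_runs:
    print(run.run_id, run.objective, run.hyperparameters)
```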
@@ -215,7 +215,7 @@ class BestRun(NamedTuple):
     """

     run_id: str
-    objective: float
+    objective: Union[float, List[float]]
    hyperparameters: Dict[str, Any]
     run_summary: Optional[Any] = None
......
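With the widened `objective` annotation, both of these shapes are valid (values are illustrative):

```py
from transformers.trainer_utils import BestRun

single = BestRun(run_id="0", objective=0.42, hyperparameters={"a": 1})
pareto = BestRun(run_id="3", objective=[0.42, 0.91], hyperparameters={"a": 2})
```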
@@ -26,6 +26,7 @@ import tempfile
 import unittest
 from itertools import product
 from pathlib import Path
+from typing import Dict, List
 from unittest.mock import Mock, patch

 import numpy as np
......
@@ -2310,6 +2311,62 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
             trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)


+@require_torch
+@require_optuna
+class TrainerHyperParameterMultiObjectOptunaIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments("..")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+            return {}
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.suggest_int("a", -4, 4)
+                b = trial.suggest_int("b", -4, 4)
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+            return RegressionPreTrainedModel(config)
+
+        def hp_name(trial):
+            return MyTrialShortNamer.shortname(trial.params)
+
+        def compute_objective(metrics: Dict[str, float]) -> List[float]:
+            return metrics["eval_loss"], metrics["eval_accuracy"]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                logging_steps=1,
+                evaluation_strategy=IntervalStrategy.EPOCH,
+                save_strategy=IntervalStrategy.EPOCH,
+                num_train_epochs=10,
+                disable_tqdm=True,
+                load_best_model_at_end=True,
+                logging_dir="runs",
+                run_name="test",
+                model_init=model_init,
+                compute_metrics=AlmostAccuracy(),
+            )
+            trainer.hyperparameter_search(
+                direction=["minimize", "maximize"],
+                hp_space=hp_space,
+                hp_name=hp_name,
+                n_trials=4,
+                compute_objective=compute_objective,
+            )
+
+
 @require_torch
 @require_ray
 class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
......