"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f80d631b371a075e0a36d22f7ce6b150f05943ef"
Commit e174bfeb authored by François Lagunas, committed via GitHub

TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)

Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
parent bf162ce8
@@ -85,6 +85,17 @@ def is_ray_available():
     return _has_ray


+def hp_params(trial):
+    if is_optuna_available():
+        if isinstance(trial, optuna.Trial):
+            return trial.params
+    if is_ray_available():
+        if isinstance(trial, dict):
+            return trial
+
+    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
+
+
 def default_hp_search_backend():
     if is_optuna_available():
         return "optuna"
@@ -192,6 +203,18 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     return best_run


+def rewrite_logs(d):
+    new_d = {}
+    eval_prefix = "eval_"
+    eval_prefix_len = len(eval_prefix)
+    for k, v in d.items():
+        if k.startswith(eval_prefix):
+            new_d["eval/" + k[eval_prefix_len:]] = v
+        else:
+            new_d["train/" + k] = v
+    return new_d
+
+
 class TensorBoardCallback(TrainerCallback):
     """
     A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
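As a quick illustration (not part of the diff), `rewrite_logs` groups metrics under `train/` and `eval/` namespaces so TensorBoard and W&B display them in separate sections:

# Illustrative only: what rewrite_logs does to a Trainer log dict.
from transformers.integrations import rewrite_logs

logs = {"loss": 0.52, "learning_rate": 3e-5, "eval_loss": 0.61, "eval_accuracy": 0.84}
print(rewrite_logs(logs))
# -> {'train/loss': 0.52, 'train/learning_rate': 3e-05, 'eval/loss': 0.61, 'eval/accuracy': 0.84}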
@@ -208,17 +231,39 @@ class TensorBoardCallback(TrainerCallback):
         ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
         self.tb_writer = tb_writer

-    def on_init_end(self, args, state, control, **kwargs):
-        if self.tb_writer is None and state.is_world_process_zero:
-            self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
+    def _init_summary_writer(self, args, log_dir=None):
+        log_dir = log_dir or args.logging_dir
+        self.tb_writer = SummaryWriter(log_dir=log_dir)

     def on_train_begin(self, args, state, control, **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        log_dir = None
+
+        if state.is_hyper_param_search:
+            trial_name = state.trial_name
+            if trial_name is not None:
+                log_dir = os.path.join(args.logging_dir, trial_name)
+
+        self._init_summary_writer(args, log_dir)
+
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", args.to_json_string())
+            if "model" in kwargs:
+                model = kwargs["model"]
+                if hasattr(model, "config") and model.config is not None:
+                    model_config_json = model.config.to_json_string()
+                    self.tb_writer.add_text("model_config", model_config_json)
             self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

     def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_world_process_zero:
+            if self.tb_writer is None:
+                self._init_summary_writer(args)
+
         if self.tb_writer:
+            logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     self.tb_writer.add_scalar(k, v, state.global_step)
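A minimal sketch (not from the patch) of the resulting TensorBoard layout: during a hyperparameter search the callback opens one SummaryWriter per trial, rooted under the trial's name inside `logging_dir`. The trial name below is hypothetical.

# Illustrative only: per-trial TensorBoard log directory.
import os

logging_dir = "runs"          # TrainingArguments.logging_dir
trial_name = "hp_a4_b-2"      # state.trial_name, as produced by the hp_name hook (hypothetical)
per_trial_log_dir = os.path.join(logging_dir, trial_name)
print(per_trial_log_dir)      # runs/hp_a4_b-2 -> each trial gets its own event files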
@@ -249,7 +294,7 @@ class WandbCallback(TrainerCallback):
         assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
         self._initialized = False

-    def setup(self, args, state, model):
+    def setup(self, args, state, model, reinit, **kwargs):
         """
         Setup the optional Weights & Biases (`wandb`) integration.

@@ -271,21 +316,41 @@ class WandbCallback(TrainerCallback):
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
             combined_dict = {**args.to_sanitized_dict()}
-            if getattr(model, "config", None) is not None:
-                combined_dict = {**model.config.to_dict(), **combined_dict}
-            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
+
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            trial_name = state.trial_name
+            init_args = {}
+            if trial_name is not None:
+                run_name = trial_name
+                init_args["group"] = args.run_name
+            else:
+                run_name = args.run_name
+            wandb.init(
+                project=os.getenv("WANDB_PROJECT", "huggingface"),
+                config=combined_dict,
+                name=run_name,
+                reinit=reinit,
+                **init_args,
+            )
             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                 wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
-        if not self._initialized:
-            self.setup(args, state, model)
+        hp_search = state.is_hyper_param_search
+        if not self._initialized or hp_search:
+            print(args.run_name)
+            self.setup(args, state, model, reinit=hp_search, **kwargs)

     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
         if not self._initialized:
-            self.setup(args, state, model)
+            self.setup(args, state, model, reinit=False)
         if state.is_world_process_zero:
+            logs = rewrite_logs(logs)
             wandb.log(logs, step=state.global_step)
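A minimal sketch (not from the patch) of the resulting W&B behaviour during a search: each trial becomes its own run, named after the trial and grouped under the original run_name, using the standard `wandb.init` `group`/`reinit` arguments. Trial names and the project name are placeholders.

# Illustrative only: how the callback's wandb.init call behaves across trials.
# Assumes wandb is installed and logging is not disabled via WANDB_DISABLED.
import wandb

for trial_name in ["trial-0", "trial-1"]:      # hypothetical trial names
    run = wandb.init(
        project="huggingface",
        name=trial_name,      # one W&B run per trial
        group="my-sweep",     # all trials grouped under the original run_name
        reinit=True,          # allow several runs in the same process
    )
    run.log({"train/loss": 1.0})
    run.finish()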
@@ -20,6 +20,7 @@ from .file_utils import (
     _torch_available,
     _torch_tpu_available,
 )
+from .integrations import _has_optuna, _has_ray


 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"

@@ -233,6 +234,32 @@ def require_faiss(test_case):
     return test_case


+def require_optuna(test_case):
+    """
+    Decorator marking a test that requires optuna.
+
+    These tests are skipped when optuna isn't installed.
+    """
+    if not _has_optuna:
+        return unittest.skip("test requires optuna")(test_case)
+    else:
+        return test_case
+
+
+def require_ray(test_case):
+    """
+    Decorator marking a test that requires Ray/tune.
+
+    These tests are skipped when Ray/tune isn't installed.
+    """
+    if not _has_ray:
+        return unittest.skip("test requires Ray/tune")(test_case)
+    else:
+        return test_case
+
+
 def get_tests_dir(append_path=None):
     """
     Args:
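Illustrative usage only (hypothetical test class, not from the patch): the new decorators skip tests when the corresponding search backend is missing.

# Sketch: guard backend-specific tests with the new decorators.
import unittest

from transformers.testing_utils import require_optuna, require_ray


class HyperParameterSearchBackendTest(unittest.TestCase):
    @require_optuna
    def test_needs_optuna(self):
        import optuna  # safe: the test is skipped if optuna is missing

        self.assertTrue(hasattr(optuna, "create_study"))

    @require_ray
    def test_needs_ray(self):
        from ray import tune  # safe: the test is skipped if ray[tune] is missing

        self.assertTrue(hasattr(tune, "run"))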
@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
+    hp_params,
     is_comet_available,
     is_optuna_available,
     is_ray_available,
@@ -224,6 +225,7 @@ class Trainer:
             model is not None or model_init is not None
         ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
         self.model_init = model_init
+        self.hp_name = None
         if model is None and model_init is not None:
             model = self.call_model_init()
         self.model = model.to(args.device) if model is not None else None
@@ -508,8 +510,11 @@ class Trainer:
     def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         """ HP search setup code """
+        self._trial = trial
+
         if self.hp_search_backend is None or trial is None:
             return
+
         params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
         for key, value in params.items():
             if not hasattr(self.args, key):
@@ -558,7 +563,10 @@ class Trainer:
         elif model_init_argcount == 1:
             model = self.model_init(trial)
         else:
-            raise Exception("model_init should have 0 or 1 argument.")
+            raise RuntimeError("model_init should have 0 or 1 argument.")
+
+        if model is None:
+            raise RuntimeError("model_init should not return None.")

         return model
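For context, a hedged sketch (not from the patch) of a `model_init` function: it may take no arguments or a single `trial`, and must return a freshly initialized model, never None. The model name and the tuned config attribute below are illustrative choices, not part of the PR.

# Hypothetical model_init: trial is None for the initial model and an optuna.Trial
# (or a dict with Ray Tune) during hyperparameter_search.
from transformers import AutoConfig, AutoModelForSequenceClassification


def model_init(trial):
    config = AutoConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
    if trial is not None and hasattr(trial, "suggest_float"):
        # only the optuna path exposes suggest_*; a Ray Tune trial is a plain dict
        config.seq_classif_dropout = trial.suggest_float("seq_classif_dropout", 0.0, 0.3)
    return AutoModelForSequenceClassification.from_config(config)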
@@ -617,6 +625,7 @@ class Trainer:
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
+        self.state.is_hyper_param_search = trial is not None

         # Check if saved optimizer or scheduler states exist
         if (
@@ -702,6 +711,8 @@ class Trainer:
         self.callback_handler.optimizer = self.optimizer
         self.callback_handler.lr_scheduler = self.lr_scheduler
         self.callback_handler.train_dataloader = train_dataloader
+        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
+        self.state.trial_params = hp_params(trial) if trial is not None else None
         # This should be the same if the state has been saved but in case the training arguments changed, it's safer
         # to set this after the load.
         self.state.max_steps = max_steps
@@ -783,13 +794,13 @@ class Trainer:
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

-                    self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
+                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break

             self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
-            self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
+            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -823,7 +834,7 @@ class Trainer:
         return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

-    def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
+    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
             logs: Dict[str, float] = {}
             tr_loss_scalar = tr_loss.item()
@@ -842,6 +853,7 @@ class Trainer:
         if self.control.should_evaluate:
             metrics = self.evaluate()
             self._report_to_hp_search(trial, epoch, metrics)
+
             self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

         if self.control.should_save:
@@ -857,9 +869,12 @@ class Trainer:
             assert model is self.model, f"Model {model} should be a reference to self.model"
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
         if self.hp_search_backend is not None and trial is not None:
             run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
-            checkpoint_folder += f"-run-{run_id}"
-        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
+            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
+        else:
+            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

         self.store_flos()
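A minimal sketch (not from the patch) of how the checkpoint path is now assembled; the run and step values are hypothetical.

# Illustrative only: checkpoint layout with and without a hyperparameter search.
import os

output_dir = "output"
checkpoint_folder = "checkpoint-500"   # f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
run_name = "run-3"                     # hp_name(trial) if provided, else f"run-{run_id}"

print(os.path.join(output_dir, run_name, checkpoint_folder))  # output/run-3/checkpoint-500 (during a search)
print(os.path.join(output_dir, checkpoint_folder))            # output/checkpoint-500 (normal training)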
@@ -909,6 +924,7 @@ class Trainer:
         n_trials: int = 20,
         direction: str = "minimize",
         backend: Optional[Union["str", HPSearchBackend]] = None,
+        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
         **kwargs
     ) -> BestRun:
         """
@@ -966,13 +982,13 @@ class Trainer:
                     "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
                 )
         self.hp_search_backend = backend
         if self.model_init is None:
             raise RuntimeError(
                 "To use hyperparameter search, you need to pass your model through a model_init function."
             )
         self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+        self.hp_name = hp_name
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

         run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
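A hedged usage sketch (not from the patch) of the new `hp_name` hook, assuming optuna is installed and `trainer` was built with a `model_init`; the search space and naming scheme are illustrative.

# Illustrative only: name each trial from its sampled hyperparameters.
def hp_space(trial):
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)}


def hp_name(trial):
    return f"lr-{trial.params['learning_rate']:.1e}"  # becomes the trial/run name in logs and checkpoints


best_run = trainer.hyperparameter_search(
    hp_space=hp_space,
    hp_name=hp_name,
    direction="minimize",
    n_trials=8,
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)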
@@ -997,12 +1013,12 @@ class Trainer:
                 FutureWarning,
             )
             return self._log(logs)
         if self.state.epoch is not None:
             logs["epoch"] = self.state.epoch
         if self._total_flos is not None:
             self.store_flos()
             logs["total_flos"] = self.state.total_flos
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
@@ -19,7 +19,7 @@ Callbacks to use with the Trainer class and customize the training loop.
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from tqdm.auto import tqdm

@@ -66,6 +66,9 @@ class TrainerState:
         is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not this process is the global main process (when training in a distributed fashion on
             several machines, this is only going to be :obj:`True` for one process).
+        is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
+            This will impact the way data will be logged in TensorBoard.
     """

     epoch: Optional[float] = None

@@ -78,6 +81,9 @@ class TrainerState:
     best_model_checkpoint: Optional[str] = None
     is_local_process_zero: bool = True
     is_world_process_zero: bool = True
+    is_hyper_param_search: bool = False
+    trial_name: str = None
+    trial_params: Dict[str, Union[str, float, int, bool]] = None

     def __post_init__(self):
         if self.log_history is None:
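As an aside, a hypothetical callback (not from the patch) showing how user code can read the new state fields; it assumes `TrainerCallback` is importable from the top-level `transformers` package.

# Illustrative only: tag log lines with the current trial during a search.
from transformers import TrainerCallback


class TrialAwareLogger(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_hyper_param_search:
            # trial_name / trial_params are filled in by Trainer.train() during a search
            print(f"[{state.trial_name}] step={state.global_step} params={state.trial_params} logs={logs}")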
@@ -16,6 +16,7 @@
 Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
 """

+import copy
 import random
 from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     Return:
         :obj:`float`: The objective to minimize or maximize
     """
+    metrics = copy.deepcopy(metrics)
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
     _ = metrics.pop("total_flos", None)
-    return loss if len(metrics) == 0 else sum(metrics.values())
+    if len(metrics) != 0:
+        raise RuntimeError(
+            "Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
+        )
+    return loss


 def default_hp_space_optuna(trial) -> Dict[str, float]:
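Since `default_compute_objective` now raises when evaluation returns extra metrics, a hedged sketch (not from the patch) of a custom objective; the metric key is an illustrative example.

# Illustrative only: supply your own objective when evaluation reports extra metrics.
def my_compute_objective(metrics):
    # maximize accuracy; pair this with hyperparameter_search(direction="maximize", ...)
    return metrics["eval_accuracy"]

# trainer.hyperparameter_search(direction="maximize", compute_objective=my_compute_objective, n_trials=8)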
import copy
import re


class TrialShortNamer:
    PREFIX = "hp"
    DEFAULTS = {}
    NAMING_INFO = None

    @classmethod
    def set_defaults(cls, prefix, defaults):
        cls.PREFIX = prefix
        cls.DEFAULTS = defaults
        cls.build_naming_info()

    @staticmethod
    def shortname_for_word(info, word):
        if len(word) == 0:
            return ""
        short_word = None
        if any(char.isdigit() for char in word):
            raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
        if word in info["short_word"]:
            return info["short_word"][word]
        for prefix_len in range(1, len(word) + 1):
            prefix = word[:prefix_len]
            if prefix in info["reverse_short_word"]:
                continue
            else:
                short_word = prefix
                break

        if short_word is None:
            # Paranoid fallback
            def int_to_alphabetic(integer):
                s = ""
                while integer != 0:
                    s = chr(ord("A") + integer % 10) + s
                    integer //= 10
                return s

            i = 0
            while True:
                sword = word + "#" + int_to_alphabetic(i)
                if sword in info["reverse_short_word"]:
                    continue
                else:
                    short_word = sword
                    break

        info["short_word"][word] = short_word
        info["reverse_short_word"][short_word] = word
        return short_word

    @staticmethod
    def shortname_for_key(info, param_name):
        words = param_name.split("_")

        shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]

        # We try to create a separatorless short name, but if there is a collision we have to fallback
        # to a separated short name
        separators = ["", "_"]

        for separator in separators:
            shortname = separator.join(shortname_parts)
            if shortname not in info["reverse_short_param"]:
                info["short_param"][param_name] = shortname
                info["reverse_short_param"][shortname] = param_name
                return shortname

        return param_name

    @staticmethod
    def add_new_param_name(info, param_name):
        short_name = TrialShortNamer.shortname_for_key(info, param_name)
        info["short_param"][param_name] = short_name
        info["reverse_short_param"][short_name] = param_name

    @classmethod
    def build_naming_info(cls):
        if cls.NAMING_INFO is not None:
            return

        info = dict(
            short_word={},
            reverse_short_word={},
            short_param={},
            reverse_short_param={},
        )

        field_keys = list(cls.DEFAULTS.keys())

        for k in field_keys:
            cls.add_new_param_name(info, k)

        cls.NAMING_INFO = info

    @classmethod
    def shortname(cls, params):
        cls.build_naming_info()
        assert cls.PREFIX is not None
        name = [copy.copy(cls.PREFIX)]

        for k, v in params.items():
            if k not in cls.DEFAULTS:
                raise Exception(f"You should provide a default value for the param name {k} with value {v}")
            if v == cls.DEFAULTS[k]:
                # The default value is not added to the name
                continue

            key = cls.NAMING_INFO["short_param"][k]

            if isinstance(v, bool):
                v = 1 if v else 0

            sep = "" if isinstance(v, (int, float)) else "-"
            e = f"{key}{sep}{v}"
            name.append(e)

        return "_".join(name)

    @classmethod
    def parse_repr(cls, repr):
        repr = repr[len(cls.PREFIX) + 1 :]
        if repr == "":
            values = []
        else:
            values = repr.split("_")

        parameters = {}

        for value in values:
            if "-" in value:
                p_k, p_v = value.split("-")
            else:
                p_k = re.sub("[0-9.]", "", value)
                p_v = float(re.sub("[^0-9.]", "", value))

            key = cls.NAMING_INFO["reverse_short_param"][p_k]

            parameters[key] = p_v

        for k in cls.DEFAULTS:
            if k not in parameters:
                parameters[k] = cls.DEFAULTS[k]

        return parameters
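Illustrative usage only (hypothetical subclass and values, not from the patch): deriving a compact trial name from hyperparameters and parsing it back.

# Sketch: TrialShortNamer builds short, stable run names from non-default parameters.
from transformers.utils.hp_naming import TrialShortNamer


class MyNamer(TrialShortNamer):
    PREFIX = "hp"
    DEFAULTS = {"learning_rate": 5e-5, "seed": 42}


name = MyNamer.shortname({"learning_rate": 0.001, "seed": 42})
print(name)                      # e.g. "hp_lr0.001" -- parameters left at their default (seed) are omitted
print(MyNamer.parse_repr(name))  # {'learning_rate': 0.001, 'seed': 42}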
@@ -21,9 +21,17 @@ import unittest
 import datasets
 import numpy as np

-from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
+from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
 from transformers.file_utils import WEIGHTS_NAME
-from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import (
+    get_tests_dir,
+    require_optuna,
+    require_sentencepiece,
+    require_tokenizers,
+    require_torch,
+    slow,
+)
+from transformers.utils.hp_naming import TrialShortNamer


 if is_torch_available():

@@ -142,6 +150,7 @@ if is_torch_available():
         data_collator = kwargs.pop("data_collator", None)
         optimizers = kwargs.pop("optimizers", (None, None))
         output_dir = kwargs.pop("output_dir", "./regression")
+        model_init = kwargs.pop("model_init", None)
         args = TrainingArguments(output_dir, **kwargs)
         return Trainer(
             model,
@@ -151,6 +160,7 @@ if is_torch_available():
             eval_dataset=eval_dataset,
             compute_metrics=compute_metrics,
             optimizers=optimizers,
+            model_init=model_init,
         )

@@ -617,3 +627,46 @@ class TrainerIntegrationTest(unittest.TestCase):
         # with enforced DataParallel
         assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
+@require_torch
+@require_optuna
+class TrainerHyperParameterIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+            return {}
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.suggest_int("a", -4, 4)
+                b = trial.suggest_int("b", -4, 4)
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(config)
+
+        def hp_name(trial):
+            return MyTrialShortNamer.shortname(trial.params)
+
+        trainer = get_regression_trainer(
+            learning_rate=0.1,
+            logging_steps=1,
+            evaluation_strategy=EvaluationStrategy.EPOCH,
+            num_train_epochs=4,
+            disable_tqdm=True,
+            load_best_model_at_end=True,
+            logging_dir="runs",
+            run_name="test",
+            model_init=model_init,
+        )
+        trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)