Unverified Commit e174bfeb authored by François Lagunas, committed by GitHub

TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)

Improved TensorBoard and Wandb integration, as well as Optuna and Ray Tune hyperparameter search support, with minor modifications to the Trainer core code.
parent bf162ce8
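For reference, a minimal sketch of how the extended API fits together, loosely mirroring the integration test added in this PR. The checkpoint name, the learning-rate search space, and `train_dataset`/`eval_dataset` are illustrative placeholders, not part of this commit:

```python
from transformers import (
    AutoModelForSequenceClassification,
    EvaluationStrategy,
    Trainer,
    TrainingArguments,
)
from transformers.utils.hp_naming import TrialShortNamer


class MyTrialShortNamer(TrialShortNamer):
    # Only parameters listed here (with their defaults) may appear in trial names.
    DEFAULTS = {"learning_rate": 5e-5}


def model_init(trial):
    # A fresh model for every trial; the checkpoint is just an example.
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")


def hp_space(trial):
    # Optuna backend: values suggested here are copied onto TrainingArguments for each trial.
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True)}


def hp_name(trial):
    # Short, deterministic name for the trial, used for TensorBoard/wandb runs and checkpoints.
    return MyTrialShortNamer.shortname(trial.params)


trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(
        output_dir="hp_search",
        run_name="sweep",
        logging_dir="runs",
        evaluation_strategy=EvaluationStrategy.EPOCH,
    ),
    train_dataset=train_dataset,  # placeholder: your tokenized datasets
    eval_dataset=eval_dataset,    # placeholder
)
best_run = trainer.hyperparameter_search(
    direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4
)
```

With the changes below, each trial then gets its own TensorBoard subdirectory under `logging_dir`, a wandb run grouped under `run_name`, and checkpoints nested in a per-trial folder.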
@@ -85,6 +85,17 @@ def is_ray_available():
return _has_ray
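# hp_params returns the hyperparameter dict for a trial regardless of backend:
# optuna passes an optuna.Trial (use .params), Ray Tune passes a plain dict.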
def hp_params(trial):
if is_optuna_available():
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_available():
if isinstance(trial, dict):
return trial
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def default_hp_search_backend():
if is_optuna_available():
return "optuna"
@@ -192,6 +203,18 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
return best_run
def rewrite_logs(d):
new_d = {}
eval_prefix = "eval_"
eval_prefix_len = len(eval_prefix)
for k, v in d.items():
if k.startswith(eval_prefix):
new_d["eval/" + k[eval_prefix_len:]] = v
else:
new_d["train/" + k] = v
return new_d
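# For example, rewrite_logs({"eval_loss": 0.5, "loss": 0.8}) returns
# {"eval/loss": 0.5, "train/loss": 0.8}, so evaluation and training metrics
# end up in separate TensorBoard/wandb sections.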
class TensorBoardCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
@@ -208,17 +231,39 @@ class TensorBoardCallback(TrainerCallback):
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
self.tb_writer = tb_writer
def on_init_end(self, args, state, control, **kwargs):
if self.tb_writer is None and state.is_world_process_zero:
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
self.tb_writer = SummaryWriter(log_dir=log_dir)
def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
return
log_dir = None
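# During a hyperparameter search, each trial gets its own TensorBoard subdirectory named after the trial.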
if state.is_hyper_param_search:
trial_name = state.trial_name
if trial_name is not None:
log_dir = os.path.join(args.logging_dir, trial_name)
self._init_summary_writer(args, log_dir)
if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
if "model" in kwargs:
model = kwargs["model"]
if hasattr(model, "config") and model.config is not None:
model_config_json = model.config.to_json_string()
self.tb_writer.add_text("model_config", model_config_json)
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_world_process_zero:
if self.tb_writer is None:
self._init_summary_writer(args)
if self.tb_writer:
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
@@ -249,7 +294,7 @@ class WandbCallback(TrainerCallback):
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
self._initialized = False
def setup(self, args, state, model):
def setup(self, args, state, model, reinit, **kwargs):
"""
Setup the optional Weights & Biases (`wandb`) integration.
@@ -271,21 +316,41 @@ class WandbCallback(TrainerCallback):
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_sanitized_dict()}
if getattr(model, "config", None) is not None:
combined_dict = {**model.config.to_dict(), **combined_dict}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
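# During a hyperparameter search, name the wandb run after the trial and group all trials under args.run_name.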
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
run_name = trial_name
init_args["group"] = args.run_name
else:
run_name = args.run_name
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
config=combined_dict,
name=run_name,
reinit=reinit,
**init_args,
)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
self.setup(args, state, model, reinit=hp_search, **kwargs)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
self.setup(args, state, model, reinit=False)
if state.is_world_process_zero:
logs = rewrite_logs(logs)
wandb.log(logs, step=state.global_step)
@@ -20,6 +20,7 @@ from .file_utils import (
_torch_available,
_torch_tpu_available,
)
from .integrations import _has_optuna, _has_ray
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -233,6 +234,32 @@ def require_faiss(test_case):
return test_case
def require_optuna(test_case):
"""
Decorator marking a test that requires optuna.
These tests are skipped when optuna isn't installed.
"""
if not _has_optuna:
return unittest.skip("test requires optuna")(test_case)
else:
return test_case
def require_ray(test_case):
"""
Decorator marking a test that requires Ray/tune.
These tests are skipped when Ray/tune isn't installed.
"""
if not _has_ray:
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
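# Example usage (see TrainerHyperParameterIntegrationTest below): decorate HP-search tests
# with @require_optuna and/or @require_ray in addition to @require_torch.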
def get_tests_dir(append_path=None):
"""
Args:
@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
hp_params,
is_comet_available,
is_optuna_available,
is_ray_available,
@@ -224,6 +225,7 @@ class Trainer:
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
self.hp_name = None
if model is None and model_init is not None:
model = self.call_model_init()
self.model = model.to(args.device) if model is not None else None
@@ -508,8 +510,11 @@ class Trainer:
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
self._trial = trial
if self.hp_search_backend is None or trial is None:
return
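# Resolve the trial into a flat dict of hyperparameters: optuna hands us a Trial to sample from, Ray Tune a plain dict.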
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
for key, value in params.items():
if not hasattr(self.args, key):
@@ -558,7 +563,10 @@ class Trainer:
elif model_init_argcount == 1:
model = self.model_init(trial)
else:
raise Exception("model_init should have 0 or 1 argument.")
raise RuntimeError("model_init should have 0 or 1 argument.")
if model is None:
raise RuntimeError("model_init should not return None.")
return model
@@ -617,6 +625,7 @@ class Trainer:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
# Check if saved optimizer or scheduler states exist
if (
@@ -702,6 +711,8 @@ class Trainer:
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
self.state.trial_params = hp_params(trial) if trial is not None else None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
@@ -783,13 +794,13 @@ class Trainer:
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -823,7 +834,7 @@ class Trainer:
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
@@ -842,6 +853,7 @@ class Trainer:
if self.control.should_evaluate:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
@@ -857,12 +869,15 @@ class Trainer:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
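# During a hyperparameter search, nest checkpoints under a per-trial run folder so trials do not overwrite each other's checkpoints.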
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
else:
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.store_flos()
self.save_model(output_dir)
# Save optimizer and scheduler
@@ -909,6 +924,7 @@ class Trainer:
n_trials: int = 20,
direction: str = "minimize",
backend: Optional[Union["str", HPSearchBackend]] = None,
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
**kwargs
) -> BestRun:
"""
@@ -966,13 +982,13 @@ class Trainer:
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
@@ -997,12 +1013,12 @@ class Trainer:
FutureWarning,
)
return self._log(logs)
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
@@ -19,7 +19,7 @@ Callbacks to use with the Trainer class and customize the training loop.
import dataclasses
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from tqdm.auto import tqdm
@@ -66,6 +66,9 @@ class TrainerState:
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether we are in the process of a hyperparameter search using ``Trainer.hyperparameter_search``.
This will impact the way data will be logged in TensorBoard.
"""
epoch: Optional[float] = None
@@ -78,6 +81,9 @@ class TrainerState:
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
is_hyper_param_search: bool = False
trial_name: str = None
trial_params: Dict[str, Union[str, float, int, bool]] = None
def __post_init__(self):
if self.log_history is None:
@@ -16,6 +16,7 @@
Utilities for the Trainer and TFTrainer classes. Should be independent of PyTorch and TensorFlow.
"""
import copy
import random
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
Return:
:obj:`float`: The objective to minimize or maximize
"""
metrics = copy.deepcopy(metrics)
loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None)
_ = metrics.pop("total_flos", None)
return loss if len(metrics) == 0 else sum(metrics.values())
if len(metrics) != 0:
raise RuntimeError(
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
)
return loss
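# With the default metrics this simply returns the evaluation loss, e.g.
# default_compute_objective({"eval_loss": 0.25, "epoch": 2.0}) -> 0.25.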
def default_hp_space_optuna(trial) -> Dict[str, float]:
......
import copy
import re
class TrialShortNamer:
PREFIX = "hp"
DEFAULTS = {}
NAMING_INFO = None
@classmethod
def set_defaults(cls, prefix, defaults):
cls.PREFIX = prefix
cls.DEFAULTS = defaults
cls.build_naming_info()
@staticmethod
def shortname_for_word(info, word):
if len(word) == 0:
return ""
short_word = None
if any(char.isdigit() for char in word):
raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
if word in info["short_word"]:
return info["short_word"][word]
for prefix_len in range(1, len(word) + 1):
prefix = word[:prefix_len]
if prefix in info["reverse_short_word"]:
continue
else:
short_word = prefix
break
if short_word is None:
# Paranoid fallback
def int_to_alphabetic(integer):
s = ""
while integer != 0:
s = chr(ord("A") + integer % 10) + s
integer //= 10
return s
i = 0
while True:
sword = word + "#" + int_to_alphabetic(i)
if sword in info["reverse_short_word"]:
continue
else:
short_word = sword
break
info["short_word"][word] = short_word
info["reverse_short_word"][short_word] = word
return short_word
@staticmethod
def shortname_for_key(info, param_name):
words = param_name.split("_")
shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]
# We try to create a separator-less short name, but if there is a collision we have to fall back
# to a separated short name
separators = ["", "_"]
for separator in separators:
shortname = separator.join(shortname_parts)
if shortname not in info["reverse_short_param"]:
info["short_param"][param_name] = shortname
info["reverse_short_param"][shortname] = param_name
return shortname
return param_name
@staticmethod
def add_new_param_name(info, param_name):
short_name = TrialShortNamer.shortname_for_key(info, param_name)
info["short_param"][param_name] = short_name
info["reverse_short_param"][short_name] = param_name
@classmethod
def build_naming_info(cls):
if cls.NAMING_INFO is not None:
return
info = dict(
short_word={},
reverse_short_word={},
short_param={},
reverse_short_param={},
)
field_keys = list(cls.DEFAULTS.keys())
for k in field_keys:
cls.add_new_param_name(info, k)
cls.NAMING_INFO = info
@classmethod
def shortname(cls, params):
cls.build_naming_info()
assert cls.PREFIX is not None
name = [copy.copy(cls.PREFIX)]
for k, v in params.items():
if k not in cls.DEFAULTS:
raise Exception(f"You should provide a default value for the param name {k} with value {v}")
if v == cls.DEFAULTS[k]:
# The default value is not added to the name
continue
key = cls.NAMING_INFO["short_param"][k]
if isinstance(v, bool):
v = 1 if v else 0
sep = "" if isinstance(v, (int, float)) else "-"
e = f"{key}{sep}{v}"
name.append(e)
return "_".join(name)
@classmethod
def parse_repr(cls, repr):
repr = repr[len(cls.PREFIX) + 1 :]
if repr == "":
values = []
else:
values = repr.split("_")
parameters = {}
for value in values:
if "-" in value:
p_k, p_v = value.split("-")
else:
p_k = re.sub("[0-9.]", "", value)
p_v = float(re.sub("[^0-9.]", "", value))
key = cls.NAMING_INFO["reverse_short_param"][p_k]
parameters[key] = p_v
for k in cls.DEFAULTS:
if k not in parameters:
parameters[k] = cls.DEFAULTS[k]
return parameters
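# Illustrative example, assuming a subclass with PREFIX = "hp" and DEFAULTS = {"a": 0, "b": 0}
# (as in the new Trainer test): shortname({"a": 1, "b": 0}) -> "hp_a1" (values equal to the
# default are omitted), and parse_repr("hp_a1") -> {"a": 1.0, "b": 0}.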
@@ -21,9 +21,17 @@ import unittest
import datasets
import numpy as np
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
from transformers.testing_utils import (
get_tests_dir,
require_optuna,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
)
from transformers.utils.hp_naming import TrialShortNamer
if is_torch_available():
@@ -142,6 +150,7 @@ if is_torch_available():
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
output_dir = kwargs.pop("output_dir", "./regression")
model_init = kwargs.pop("model_init", None)
args = TrainingArguments(output_dir, **kwargs)
return Trainer(
model,
@@ -151,6 +160,7 @@ if is_torch_available():
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
optimizers=optimizers,
model_init=model_init,
)
@@ -617,3 +627,46 @@ class TrainerIntegrationTest(unittest.TestCase):
# with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
@require_torch
@require_optuna
class TrainerHyperParameterIntegrationTest(unittest.TestCase):
def setUp(self):
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
def test_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
DEFAULTS = {"a": 0, "b": 0}
def hp_space(trial):
return {}
def model_init(trial):
if trial is not None:
a = trial.suggest_int("a", -4, 4)
b = trial.suggest_int("b", -4, 4)
else:
a = 0
b = 0
config = RegressionModelConfig(a=a, b=b, double_output=False)
return RegressionPreTrainedModel(config)
def hp_name(trial):
return MyTrialShortNamer.shortname(trial.params)
trainer = get_regression_trainer(
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=EvaluationStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,
logging_dir="runs",
run_name="test",
model_init=model_init,
)
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)