Unverified commit 3a7fdd3f authored by Sylvain Gugger, committed by GitHub

Add hyperparameter search to Trainer (#6576)



* Add optuna hyperparameter search to Trainer

* @julien-c suggestions
Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Make compute_objective an arg function

* Formatting

* Rework to make it easier to add ray

* Formatting

* Initial support for Ray

* Formatting

* Polish and finalize

* Add trial id to checkpoint with Ray

* Smaller default

* Use GPU in ray if available

* Formatting

* Fix test

* Update install instruction
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Address review comments

* Formatting post-merge
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
parent dd522da0
@@ -92,7 +92,13 @@ from .file_utils import (
from .hf_argparser import HfArgumentParser

# Integrations
from .integrations import (
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
)

# Model Cards
from .modelcard import ModelCard
@@ -35,6 +35,20 @@ except ImportError:
except ImportError:
    _has_tensorboard = False

try:
    import optuna  # noqa: F401

    _has_optuna = True
except ImportError:
    _has_optuna = False

try:
    import ray  # noqa: F401

    _has_ray = True
except ImportError:
    _has_ray = False


def is_wandb_available():
    return _has_wandb
@@ -46,3 +60,18 @@ def is_comet_available():
def is_tensorboard_available():
    return _has_tensorboard


def is_optuna_available():
    return _has_optuna


def is_ray_available():
    return _has_ray


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_available():
        return "ray"
@@ -21,10 +21,27 @@ from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, default_data_collator
from .file_utils import is_nlp_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
)
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments

@@ -62,6 +79,12 @@ if is_wandb_available():
if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune


logger = logging.getLogger(__name__)
@@ -140,10 +163,11 @@ class Trainer:
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named
            `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`.
@@ -153,6 +177,9 @@ class Trainer:
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this
            function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
@@ -168,21 +195,31 @@
    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using `output_dir=tmp_trainer`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        self.tb_writer = tb_writer
@@ -242,6 +279,7 @@
        self.epoch = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
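A hedged usage sketch of the new `model_init` path (the checkpoint name and output directory are placeholders, not part of this diff):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init():
    # A fresh model for every call to train(), so each hyperparameter trial
    # starts from the same pretrained weights.
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Datasets/metrics omitted for brevity; pass train_dataset/eval_dataset as usual.
trainer = Trainer(
    args=TrainingArguments(output_dir="hp_search", evaluate_during_training=True),
    model_init=model_init,  # used instead of a fixed `model`
)
```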
    def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
@@ -462,7 +500,38 @@
        """
        return len(dataloader.dataset)
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in "
                    "`TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            tune.report(objective=self.objective, **metrics)
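To make the casting behavior in `_hp_search_setup` concrete, a small self-contained sketch (`ArgsStandIn` is a hypothetical stand-in for a few `TrainingArguments` fields; the sampled values are illustrative — with Ray, `trial` is already a plain dict like `params` below):

```python
from dataclasses import dataclass

@dataclass
class ArgsStandIn:  # hypothetical stand-in, not the real TrainingArguments
    learning_rate: float = 5e-5
    num_train_epochs: float = 3.0
    seed: int = 42

args = ArgsStandIn()
params = {"num_train_epochs": 2, "seed": 17.6}  # e.g. values sampled by Ray
for key, value in params.items():
    old_attr = getattr(args, key, None)
    if old_attr is not None:
        # Mirrors `value = type(old_attr)(value)` in _hp_search_setup:
        value = type(old_attr)(value)  # 17.6 -> 17 because `seed` holds an int
    setattr(args, key, value)
print(args)  # ArgsStandIn(learning_rate=5e-05, num_train_epochs=2.0, seed=17)
```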
    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
@@ -470,7 +539,17 @@
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # Model re-init
        if self.model_init is not None:
            model = self.model_init()
            self.model = model.to(self.args.device)

        self._hp_search_setup(trial)

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
@@ -561,9 +640,8 @@
        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
@@ -572,9 +650,9 @@
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
self.log(logs) self.log(logs)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
self.evaluate() metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
# In all cases (even distributed/parallel), self.model is always a reference # In all cases (even distributed/parallel), self.model is always a reference
...@@ -643,7 +722,15 @@ class Trainer: ...@@ -643,7 +722,15 @@ class Trainer:
else: else:
assert model is self.model, f"Model {model} should be a reference to self.model" assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint # Save model checkpoint
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = (
trial.number
if self.hp_search_backend == HPSearchBackend.OPTUNA
else tune.get_trial_id()
)
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.save_model(output_dir) self.save_model(output_dir)
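For example (illustrative values), a checkpoint saved at global step 500 during optuna trial number 3 lands in `output_dir/checkpoint-500-run-3`; with Ray, the `run-` suffix is the id returned by `tune.get_trial_id()`.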
@@ -683,6 +770,108 @@
        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")

        return TrainOutput(self.global_step, tr_loss / self.global_step)
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to the evaluation loss when no metric is provided and the sum of
        all metrics otherwise (you can change that behavior by subclassing and overriding this method).

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize greater or lower objective values. Can be :obj:`"minimize"` or
                :obj:`"maximize"`; pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"`
                when optimizing one or several metrics.
            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
        def _objective(trial):
            # To make sure optimizer and lr_scheduler are reset with the new choices of HPs
            self.optimizer = None
            self.lr_scheduler = None
            self.objective = None
            self.train(trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(self, "objective", None) is None:
                metrics = self.evaluate()
                self.objective = self.compute_objective(metrics)
            return self.objective

        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            timeout = kwargs.pop("timeout", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            study = optuna.create_study(direction=direction, **kwargs)
            study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
            best_trial = study.best_trial
            best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            # The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp
            # search.
            _tb_writer = self.tb_writer
            self.tb_writer = None
            # Set up the default `resources_per_trial` and `progress_reporter`.
            if "resources_per_trial" not in kwargs and self.args.n_gpu > 0:
                kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu}
            if "progress_reporter" not in kwargs:
                from ray.tune import CLIReporter

                kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
            analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs)
            best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
            best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
            self.tb_writer = _tb_writer
        self.hp_search_backend = None
        return best_run
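A hedged end-to-end sketch of the new API (the search space is illustrative, and `trainer` is assumed to be the instance built with `model_init` in the earlier sketch):

```python
# Illustrative custom optuna space; keys must be TrainingArguments field names.
def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 3),
    }

best_run = trainer.hyperparameter_search(
    hp_space=my_hp_space,  # omit to use the backend's default space
    n_trials=10,
    direction="minimize",  # minimizing the evaluation loss
    backend="optuna",      # or "ray"; defaults to whichever is installed
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)
```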
    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.
@@ -1020,8 +1209,9 @@
        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        samples_count = 0
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            samples_count += batch_size
import random
from typing import Any, Dict, NamedTuple, Optional

import numpy as np

from .file_utils import is_tf_available, is_torch_available
from .integrations import is_ray_available
from .tokenization_utils_base import ExplicitEnum


def set_seed(seed: int):
@@ -53,3 +55,70 @@
PREFIX_CHECKPOINT_DIR = "checkpoint"
class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).

    Parameters:
        run_id (:obj:`str`):
            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
            with run-{run_id}).
        objective (:obj:`float`):
            The objective that was obtained for this run.
        hyperparameters (:obj:`Dict[str, Any]`):
            The hyperparameters picked to get this run.
    """

    run_id: str
    objective: float
    hyperparameters: Dict[str, Any]
def default_compute_objective(metrics: Dict[str, float]) -> float:
    """
    The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if no
    metrics are provided to the :class:`~transformers.Trainer`, the sum of all metrics otherwise.

    Args:
        metrics (:obj:`Dict[str, float]`): The metrics returned by the evaluate method.

    Return:
        :obj:`float`: The objective to minimize or maximize.
    """
    loss = metrics.pop("eval_loss", None)
    _ = metrics.pop("epoch", None)
    return loss if len(metrics) == 0 else sum(metrics.values())
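Two worked examples with illustrative numbers (note that ``eval_loss`` and ``epoch`` are popped before the check, so the input dict is mutated):

```python
# Only the loss is reported: the objective is the loss itself.
default_compute_objective({"eval_loss": 0.42, "epoch": 2.0})
# -> 0.42

# Extra metrics are present: the objective is their sum (the loss is excluded).
default_compute_objective(
    {"eval_loss": 0.42, "epoch": 2.0, "eval_accuracy": 0.91, "eval_f1": 0.88}
)
# -> 0.91 + 0.88 = 1.79
```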
def default_hp_space_optuna(trial) -> Dict[str, float]:
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
    }


def default_hp_space_ray(trial) -> Dict[str, float]:
    assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
    from ray import tune

    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "num_train_epochs": tune.choice(range(1, 6)),
        "seed": tune.uniform(1, 40),
        "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
    }
class HPSearchBackend(ExplicitEnum):
    OPTUNA = "optuna"
    RAY = "ray"


default_hp_space = {
    HPSearchBackend.OPTUNA: default_hp_space_optuna,
    HPSearchBackend.RAY: default_hp_space_ray,
}
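A hedged sketch of a custom Ray search space passed through the same `hp_space` hook (keys must be `TrainingArguments` field names, since `_hp_search_setup` assigns them with `setattr`; `trainer` is the instance from the earlier sketch):

```python
from ray import tune

def my_ray_hp_space(trial):
    # With Ray the `trial` argument is ignored: the space is a static config dict.
    return {
        "learning_rate": tune.loguniform(5e-6, 5e-5),
        "per_device_train_batch_size": tune.choice([8, 16]),
    }

best_run = trainer.hyperparameter_search(hp_space=my_ray_hp_space, backend="ray", n_trials=8)
```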
@@ -114,6 +114,9 @@ class TrainingArguments:
            at the next training step under the keyword argument ``mems``.
        run_name (:obj:`str`, `optional`):
            A descriptor for the run. Notably used for wandb logging.
        disable_tqdm (:obj:`bool`, `optional`):
            Whether or not to disable the tqdm progress bars. Will default to :obj:`True` if the logging level is
            above warn, :obj:`False` otherwise (including at the default warn level).
        remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
            forward method.
@@ -238,6 +241,13 @@
    run_name: Optional[str] = field(
        default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
    )
    disable_tqdm: Optional[bool] = field(
        default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
    )

    def __post_init__(self):
        if self.disable_tqdm is None:
            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

    remove_unused_columns: Optional[bool] = field(
        default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
    )
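A hedged sketch of how the new flag interacts with its logging-based default (the `output_dir` value is a placeholder):

```python
from transformers import TrainingArguments

# Left unset, disable_tqdm is resolved in __post_init__ from the effective
# logging level: bars are disabled only when the level is above WARN.
args = TrainingArguments(output_dir="out")

# Explicitly silence the bars, e.g. for hyperparameter sweeps with many trials.
quiet_args = TrainingArguments(output_dir="out", disable_tqdm=True)
print(quiet_args.disable_tqdm)  # True
```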