Commit cafa6a9e authored by Julien Chaumond

[Trainer] Ability to specify optimizer/scheduler at init

cc @patrickvonplaten @thomwolf
parent e4fd5e39
@@ -113,6 +113,7 @@ class Trainer:
     compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
     prediction_loss_only: bool
     tb_writer: Optional["SummaryWriter"] = None
+    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None

     def __init__(
         self,
@@ -124,6 +125,7 @@ class Trainer:
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         prediction_loss_only=False,
         tb_writer: Optional["SummaryWriter"] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
     ):
         """
         Trainer is a simple but feature-complete training and eval loop for PyTorch,
@@ -143,6 +145,7 @@ class Trainer:
         self.eval_dataset = eval_dataset
         self.compute_metrics = compute_metrics
         self.prediction_loss_only = prediction_loss_only
+        self.optimizers = optimizers
         if tb_writer is not None:
             self.tb_writer = tb_writer
         elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
@@ -227,6 +230,15 @@ class Trainer:
     def get_optimizers(
         self, num_training_steps: int
     ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
+        """
+        Setup the optimizer and the learning rate scheduler.
+
+        We provide a reasonable default that works well.
+        If you want to use something else, you can pass a tuple in the Trainer's init,
+        or override this method in a subclass.
+        """
+        if self.optimizers is not None:
+            return self.optimizers
         # Prepare optimizer and schedule (linear warmup and decay)
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
...
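
Below is a minimal usage sketch (not part of the commit) of the two customization paths this change enables: passing an (optimizer, scheduler) tuple at init, or overriding get_optimizers in a subclass. The model and train_dataset objects are hypothetical placeholders; only the Trainer signature and the optimizers argument come from the diff above.

# Hypothetical usage sketch; `model` and `train_dataset` stand in for
# any PyTorch model and dataset and are not part of this commit.
import torch
from transformers import Trainer, TrainingArguments

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# Any torch.optim.lr_scheduler.LambdaLR works; here, linear decay over 10,000 steps.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / 10_000)
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./output"),
    train_dataset=train_dataset,
    # The (optimizer, scheduler) tuple short-circuits get_optimizers(),
    # skipping the default optimizer and linear warmup/decay schedule.
    optimizers=(optimizer, scheduler),
)

# Alternative: override get_optimizers in a subclass.
class MyTrainer(Trainer):
    def get_optimizers(self, num_training_steps):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
        return optimizer, scheduler

In the tuple case, get_optimizers returns early via the new `if self.optimizers is not None` branch, so the grouped-parameter weight-decay setup in the default path is never built.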