Unverified commit 7cb52f53, authored by Julien Plu and committed by GitHub

Fix LR decay in TF Trainer (#5269)

* Recover old PR

* Apply style

* Trigger CI
parent 321c05ab
@@ -3,6 +3,7 @@
 import logging
 import math
 import os
+import random
 from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
@@ -21,6 +22,12 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)
 
 
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    tf.random.set_seed(seed)
+
+
 class TFTrainer:
     model: TFPreTrainedModel
     args: TFTrainingArguments
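The new `set_seed` helper seeds Python's `random` module, NumPy, and TensorFlow in one call, which is why `import random` is added above. A minimal standalone sketch of how it behaves; the call and the printed check at the end are illustrative, not part of the diff:

```python
import random

import numpy as np
import tensorflow as tf


def set_seed(seed: int):
    random.seed(seed)         # Python's built-in RNG
    np.random.seed(seed)      # NumPy RNG (numpy-based preprocessing, sampling)
    tf.random.set_seed(seed)  # TensorFlow global seed (op seeds derived from it)


set_seed(42)
print(tf.random.uniform((1,)))  # same value on every run with the same seed
```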
@@ -59,6 +66,7 @@ class TFTrainer:
             self.tb_writer = tb_writer
         else:
             self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
+
         if is_wandb_available():
             self._setup_wandb()
         else:
@@ -67,6 +75,8 @@ class TFTrainer:
                 "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
             )
 
+        set_seed(self.args.seed)
+
     def get_train_tfdataset(self) -> tf.data.Dataset:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -109,7 +119,7 @@ class TFTrainer:
         return self.args.strategy.experimental_distribute_dataset(ds)
 
     def get_optimizers(
-        self,
+        self, num_training_steps: int,
     ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
         """
         Setup the optimizer and the learning rate scheduler.
@@ -123,7 +133,7 @@ class TFTrainer:
             optimizer, scheduler = create_optimizer(
                 self.args.learning_rate,
-                self.train_steps,
+                num_training_steps,
                 self.args.warmup_steps,
                 adam_epsilon=self.args.adam_epsilon,
                 weight_decay_rate=self.args.weight_decay,
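This is the core of the fix: `create_optimizer` now receives the full training-step budget (`num_training_steps`) instead of `self.train_steps`, which counted only a single pass over the training set, so the learning rate decayed to its floor far too early. A hedged sketch of the warmup-plus-linear-decay idea, written with plain Keras primitives rather than the library's own `create_optimizer`; the class name and values below are illustrative:

```python
import tensorflow as tf


class WarmupLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, init_lr: float, num_training_steps: int, num_warmup_steps: int):
        self.init_lr = init_lr
        self.num_training_steps = num_training_steps
        self.num_warmup_steps = num_warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = tf.cast(self.num_warmup_steps, tf.float32)
        total = tf.cast(self.num_training_steps, tf.float32)
        # Linear warmup to init_lr, then linear decay to 0 at num_training_steps.
        warmup_lr = self.init_lr * step / tf.maximum(warmup, 1.0)
        decay_lr = self.init_lr * tf.maximum(0.0, (total - step) / tf.maximum(total - warmup, 1.0))
        return tf.where(step < warmup, warmup_lr, decay_lr)


# If the schedule is built from a per-epoch step count, it reaches zero after the
# first epoch; building it from the whole-run budget keeps the decay horizon correct.
schedule = WarmupLinearDecay(init_lr=5e-5, num_training_steps=10_000, num_warmup_steps=500)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
```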
@@ -238,14 +248,19 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _log(self, logs: Dict[str, float]) -> None:
+        logs["epoch"] = self.epoch_logging
+
         if self.tb_writer:
             with self.tb_writer.as_default():
                 for k, v in logs.items():
                     tf.summary.scalar(k, v, step=self.global_step)
             self.tb_writer.flush()
+
         if is_wandb_available():
             wandb.log(logs, step=self.global_step)
+
         output = {**logs, **{"step": self.global_step}}
+
         logger.info(output)
 
     def evaluate(
@@ -260,6 +275,7 @@ class TFTrainer:
         logs = {**output.metrics}
         logs["epoch"] = self.epoch_logging
+
         self._log(logs)
 
         return output.metrics
@@ -275,25 +291,45 @@ class TFTrainer:
         self.gradient_accumulator.reset()
 
+        if self.args.max_steps > 0:
+            t_total = self.args.max_steps
+            steps_per_epoch = self.args.max_steps
+        else:
+            if self.args.dataloader_drop_last:
+                approx = math.floor
+            else:
+                approx = math.ceil
+
+            steps_per_epoch = approx(
+                self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
+            )
+            t_total = steps_per_epoch * self.args.num_train_epochs
+
         with self.args.strategy.scope():
-            optimizer, lr_scheduler = self.get_optimizers()
+            optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
             iterations = optimizer.iterations
+            self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
 
             if self.model.ckpt_manager.latest_checkpoint:
+                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
+                steps_trained_in_current_epoch = self.global_step % (
+                    self.num_train_examples // self.args.gradient_accumulation_steps
+                )
+
+                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+                logger.info("  Continuing training from epoch %d", epochs_trained)
+                logger.info("  Continuing training from global step %d", self.global_step)
+                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
                 logger.info(
                     "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                 )
+
                 ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
-
-            if iterations.numpy() > 0:
-                logger.info("Start the training from the last checkpoint")
-                start_epoch = (iterations.numpy() // self.train_steps) + 1
-            else:
-                start_epoch = 1
+            else:
+                epochs_trained = 1
 
         tf.summary.experimental.set_step(iterations)
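The block added above derives `steps_per_epoch` and `t_total` before the optimizer is created, so the learning-rate schedule can cover the whole run (or the explicit `max_steps` cap). A small worked example of the same arithmetic with made-up numbers, not taken from the diff:

```python
import math

# Assumed values for illustration only.
num_train_examples = 10_000
train_batch_size = 32            # total batch size across replicas
gradient_accumulation_steps = 2
num_train_epochs = 3
dataloader_drop_last = False

approx = math.floor if dataloader_drop_last else math.ceil
steps_per_epoch = approx(num_train_examples / (train_batch_size * gradient_accumulation_steps))
t_total = steps_per_epoch * num_train_epochs

print(steps_per_epoch, t_total)  # 157 471 -> the schedule now decays over 471 steps
```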
@@ -311,17 +347,23 @@ class TFTrainer:
         logger.info("***** Running training *****")
         logger.info("  Num examples = %d", self.num_train_examples)
         logger.info("  Num Epochs = %d", epochs)
-        logger.info("  Total optimization steps = %d", self.train_steps)
+        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
+        logger.info(
+            "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size
+        )
+        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
+        logger.info("  Total optimization steps = %d", t_total)
 
-        for epoch_iter in range(start_epoch, int(epochs + 1)):
+        for epoch_iter in range(epochs_trained, int(epochs + 1)):
             for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
                 self.global_step = iterations.numpy()
-                self.epoch_logging = epoch_iter - 1 + (step + 1) / self.train_steps
+                self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
 
                 if self.args.debug:
                     logs = {}
                     logs["loss"] = training_loss.numpy()
                     logs["epoch"] = self.epoch_logging
 
                     self._log(logs)
 
                 if self.global_step == 1 and self.args.debug:
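`epoch_logging` is now a fraction of `steps_per_epoch` rather than of `self.train_steps`, so it reports progress such as 1.25 for "a quarter of the way through the second epoch". A tiny illustration; the names mirror the diff, the values are invented:

```python
steps_per_epoch = 100  # assumed value


def epoch_logging(epoch_iter: int, step: int) -> float:
    # epoch_iter is 1-based, step is the 0-based index within the current epoch
    return epoch_iter - 1 + (step + 1) / steps_per_epoch


print(epoch_logging(epoch_iter=2, step=24))  # 1.25
```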
@@ -333,18 +375,23 @@ class TFTrainer:
                 if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                     self.evaluate()
 
-                if self.global_step % self.args.logging_steps == 0:
+                if (
+                    self.global_step % self.args.logging_steps == 0
+                    or self.global_step == 1
+                    and self.args.logging_first_step
+                ):
                     logs = {}
                     logs["loss"] = training_loss.numpy()
                     logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
                     logs["epoch"] = self.epoch_logging
 
                     self._log(logs)
 
                 if self.global_step % self.args.save_steps == 0:
                     ckpt_save_path = self.model.ckpt_manager.save()
 
                     logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
 
-                if self.global_step % self.train_steps == 0:
+                if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
                     break
 
     def _training_steps(self, ds, optimizer):
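Two smaller behaviour changes sit in this hunk: the loop now breaks early only when an explicit `max_steps` cap is reached (previously it broke whenever `global_step` hit a multiple of `self.train_steps`), and `logging_first_step` forces a log line at step 1. Note that Python's `and` binds tighter than `or`, so the new condition reads as "(logging interval reached) or (first step and logging_first_step)". A quick check of that reading with invented values:

```python
# Assumed values, for illustration only.
logging_steps = 500
logging_first_step = True


def should_log(global_step: int) -> bool:
    # Same shape as the new condition in the diff: `and` is evaluated before `or`.
    return global_step % logging_steps == 0 or global_step == 1 and logging_first_step


print([s for s in (1, 2, 500, 999, 1000) if should_log(s)])  # [1, 500, 1000]
```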
......