Unverified Commit 7169d1ea authored by Sylvain Gugger, committed by GitHub

Store FLOS as floats to avoid overflow. (#10213)

parent df1b0fb5
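Why a float? A reading of the change (an assumption; the commit title only states the goal): Python's own integers never overflow, but as soon as the FLO count is packed into an int64 torch tensor (e.g. when logging or syncing across processes), anything past 2**63 - 1 breaks, while a float merely loses precision. A minimal sketch of that presumed failure mode:

```python
import torch

# Sketch of the presumed failure mode (assumption: the overflow happens when
# the running FLO count is packed into an int64 tensor, not in pure Python).
flo_count = 2**63  # plausible for a very long run of a very large model

# Python ints are arbitrary precision, so accumulation itself never overflows:
assert flo_count + flo_count == 2**64

# But an int64 tensor cannot represent it:
try:
    torch.tensor(flo_count, dtype=torch.int64)
except RuntimeError as err:
    print(err)  # overflow when converting the Python int to int64

# A float survives as an approximation, which is good enough for reporting:
print(torch.tensor(float(flo_count)))  # tensor(9.2234e+18)
```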
@@ -959,7 +959,7 @@ class Trainer:
                     tr_loss += self.training_step(model, inputs)
                 else:
                     tr_loss += self.training_step(model, inputs)
-                self._total_flos += self.floating_point_ops(inputs)
+                self._total_flos += float(self.floating_point_ops(inputs))
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
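Note the cast is applied per step rather than once at the end: self._total_flos stays a Python float from the first update on, so any later consumer sees a float. A toy version of the accumulation pattern (names are illustrative, not Trainer internals):

```python
# Toy version of the accumulation above; step_flos stands in for
# self.floating_point_ops(inputs) and is purely illustrative.
total_flos = 0.0
step_flos = 3 * 10**18  # per-step FLO count for a hypothetical large model

for _ in range(10):
    total_flos += float(step_flos)

# The float total exceeds 2**63 - 1 but remains usable:
print(total_flos)                     # 3e+19
print(total_flos > 2**63 - 1)         # True
print(isinstance(total_flos, float))  # True
```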
@@ -52,8 +52,9 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (:obj:`int`, `optional`, defaults to 0):
             The number of update steps to do during the current training.
-        total_flos (:obj:`int`, `optional`, defaults to 0):
-            The total number of floating operations done by the model since the beginning of training.
+        total_flos (:obj:`float`, `optional`, defaults to 0):
+            The total number of floating operations done by the model since the beginning of training (stored as floats
+            to avoid overflow).
         log_history (:obj:`List[Dict[str, float]]`, `optional`):
             The list of logs done since the beginning of training.
         best_metric (:obj:`float`, `optional`):
@@ -76,7 +77,7 @@ class TrainerState:
     global_step: int = 0
     max_steps: int = 0
     num_train_epochs: int = 0
-    total_flos: int = 0
+    total_flos: float = 0
     log_history: List[Dict[str, float]] = None
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
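TrainerState is written out as JSON in checkpoints (via its save_to_json method), and a float field round-trips through json just as cleanly as an int did, so the type change costs nothing in the serialized format. A self-contained sketch with a stand-in dataclass (StateSketch is hypothetical, not the real TrainerState):

```python
import dataclasses
import json

# Stand-in for TrainerState (illustrative only): a float total_flos
# serializes through JSON just as cleanly as an int.
@dataclasses.dataclass
class StateSketch:
    total_flos: float = 0

state = StateSketch(total_flos=float(2**64))
blob = json.dumps(dataclasses.asdict(state), indent=2)
restored = StateSketch(**json.loads(blob))
print(restored.total_flos)  # 1.8446744073709552e+19
```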
@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
 
         # with enforced DataParallel
         assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
 
+        trainer.train()
+        self.assertTrue(isinstance(trainer.state.total_flos, float))
+
     @require_torch
     @require_optuna