Unverified Commit 7169d1ea authored by Sylvain Gugger, committed by GitHub

Store FLOS as floats to avoid overflow. (#10213)

parent df1b0fb5
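
The commit message only says "to avoid overflow", so as context: Python ints are arbitrary-precision, but the FLO counter also passes through fixed-width integer storage (for instance an int64 tensor when the count is summed across distributed processes), and int64 caps at 2**63 - 1 ≈ 9.2e18, a bound that long runs on large models can exceed. A minimal sketch of the magnitudes involved; the constant name and the example count below are illustrative, not part of the patch:

```python
# Why a float: fixed-width integer storage overflows long before a float
# loses the ability to represent the magnitude of a FLO count.
INT64_MAX = 2**63 - 1  # ~9.22e18, the cap for an int64 tensor or C long

total_flos = 10**20  # illustrative: a long run on a large model

# As a Python int this is fine, but packed into an int64 it would overflow.
assert total_flos > INT64_MAX

# As a float the magnitude fits comfortably (doubles reach ~1.8e308); the
# lost integer precision is irrelevant for an operation-count estimate.
total_flos = float(total_flos)
print(f"{total_flos:.3e}")  # 1.000e+20
```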
src/transformers/trainer.py
@@ -959,7 +959,7 @@ class Trainer:
                         tr_loss += self.training_step(model, inputs)
                 else:
                     tr_loss += self.training_step(model, inputs)
-                self._total_flos += self.floating_point_ops(inputs)
+                self._total_flos += float(self.floating_point_ops(inputs))
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
src/transformers/trainer_callback.py
@@ -52,8 +52,9 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (:obj:`int`, `optional`, defaults to 0):
             The number of update steps to do during the current training.
-        total_flos (:obj:`int`, `optional`, defaults to 0):
-            The total number of floating operations done by the model since the beginning of training.
+        total_flos (:obj:`float`, `optional`, defaults to 0):
+            The total number of floating operations done by the model since the beginning of training (stored as floats
+            to avoid overflow).
         log_history (:obj:`List[Dict[str, float]]`, `optional`):
             The list of logs done since the beginning of training.
         best_metric (:obj:`float`, `optional`):
@@ -76,7 +77,7 @@ class TrainerState:
     global_step: int = 0
     max_steps: int = 0
     num_train_epochs: int = 0
-    total_flos: int = 0
+    total_flos: float = 0
     log_history: List[Dict[str, float]] = None
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
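
TrainerState is written out as trainer_state.json at each checkpoint, so the field type matters for serialization as well. A quick round-trip sketch, assuming the save_to_json / load_from_json helpers that TrainerState defines in this same file; the 1.2e20 value is illustrative:

```python
import os
import tempfile

from transformers.trainer_callback import TrainerState

# Sketch: the float FLO counter should survive a trip through the
# checkpoint JSON unchanged in type and magnitude.
state = TrainerState(total_flos=1.2e20)  # illustrative value > int64 max

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trainer_state.json")
    state.save_to_json(path)
    reloaded = TrainerState.load_from_json(path)

assert isinstance(reloaded.total_flos, float)
assert reloaded.total_flos == 1.2e20
```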
tests/test_trainer.py
@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
         # with enforced DataParallel
         assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
 
+        trainer.train()
+        self.assertTrue(isinstance(trainer.state.total_flos, float))
+
 
 @require_torch
 @require_optuna