Unverified commit a7da2996 authored by Zach Mueller, committed by GitHub

Fix issue with ratio evaluation steps and auto find batch size (#25436)

* Fully rebased solution

* 500
parent 2d6839ea
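
For context before the diff: `logging_steps`, `eval_steps`, and `save_steps` accept either an integer step count or a float in [0, 1) interpreted as a ratio of `max_steps`. A minimal sketch of that conversion, using hypothetical numbers:

```python
import math

max_steps = 1000   # hypothetical total number of update steps
eval_steps = 0.05  # ratio form: "evaluate every 5% of training"

# Same conversion the Trainer applies: a value below 1 is scaled by
# max_steps and rounded up to an absolute step count.
if eval_steps and eval_steps < 1:
    eval_steps = math.ceil(max_steps * eval_steps)

print(eval_steps)  # 50
```

The trouble, as the diff below shows, was where the converted value was stored: writing it back into `args` discards the ratio, so when `auto_find_batch_size` recovers from an OOM by shrinking the batch size and recomputing `max_steps`, the stale absolute value is reused instead of being re-derived. The fix moves the resolved values onto `TrainerState`, which is rebuilt on every (re)start of the training loop.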
@@ -746,7 +746,7 @@ class WandbCallback(TrainerCallback):
             # keep track of model topology and gradients, unsupported on TPU
             _watch_model = os.getenv("WANDB_WATCH", "false")
             if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
-                self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))
+                self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
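
One note on the hunk above: after this patch `args.logging_steps` may legitimately still hold a float ratio such as `0.1`, and `max(100, 0.1)` silently evaluates to `100`, so reading `args` would pin wandb's watch frequency at the floor no matter what was configured. Reading the resolved integer from `TrainerState` keeps the configured interval. A quick illustration with hypothetical values:

```python
args_logging_steps = 0.1     # ratio form, as the user may have passed it
state_logging_steps = 230    # resolved against a hypothetical max_steps of 2300

# Old read: the ratio float always loses to the 100 floor.
print(max(100, args_logging_steps))   # 100

# New read: the resolved absolute interval is respected.
print(max(100, state_logging_steps))  # 230
```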
@@ -1586,14 +1586,6 @@ class Trainer:
                     f" {args.max_steps}"
                 )

-        # Compute absolute values for logging, eval, and save if given as ratio
-        if args.logging_steps and args.logging_steps < 1:
-            args.logging_steps = math.ceil(max_steps * args.logging_steps)
-        if args.eval_steps and args.eval_steps < 1:
-            args.eval_steps = math.ceil(max_steps * args.eval_steps)
-        if args.save_steps and args.save_steps < 1:
-            args.save_steps = math.ceil(max_steps * args.save_steps)
-
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
                 # nn.DataParallel(model) replicates the model, creating new variables and module
@@ -1627,6 +1619,23 @@ class Trainer:
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None

+        # Compute absolute values for logging, eval, and save if given as ratio
+        if args.logging_steps is not None:
+            if args.logging_steps < 1:
+                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+            else:
+                self.state.logging_steps = args.logging_steps
+        if args.eval_steps is not None:
+            if args.eval_steps < 1:
+                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+            else:
+                self.state.eval_steps = args.eval_steps
+        if args.save_steps is not None:
+            if args.save_steps < 1:
+                self.state.save_steps = math.ceil(max_steps * args.save_steps)
+            else:
+                self.state.save_steps = args.save_steps
+
         # Activate gradient checkpointing if needed
         if args.gradient_checkpointing:
             self.model.gradient_checkpointing_enable()
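
The relocated block runs right after `self.state = TrainerState()`, which is the point of the fix: when `auto_find_batch_size` hits an OOM, the training loop is re-entered with a smaller batch size and a recomputed `max_steps`, and the fresh `TrainerState` picks up freshly resolved intervals while `args` keeps the original ratio. A self-contained sketch of that control flow, with hypothetical numbers and a stand-in resolver:

```python
import math
from dataclasses import dataclass


@dataclass
class State:  # stand-in for TrainerState
    eval_steps: int = 500


def resolve(state, max_steps, args_eval_steps):
    # Mirror of the patched logic: ratios (< 1) are scaled by max_steps,
    # absolute values pass through; args is never mutated.
    if args_eval_steps is not None:
        if args_eval_steps < 1:
            state.eval_steps = math.ceil(max_steps * args_eval_steps)
        else:
            state.eval_steps = args_eval_steps


args_eval_steps = 0.05  # user asked for "every 5% of training"

# First attempt: batch size 32 -> 1000 update steps.
state = State()
resolve(state, max_steps=1000, args_eval_steps=args_eval_steps)
print(state.eval_steps)  # 50

# OOM retry: batch size 16 -> 2000 update steps. Because args still holds
# the ratio, the new state resolves to 100 instead of reusing a stale 50.
state = State()
resolve(state, max_steps=2000, args_eval_steps=args_eval_steps)
print(state.eval_steps)  # 100
```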
@@ -53,6 +53,12 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (`int`, *optional*, defaults to 0):
             The number of update steps to do during the current training.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Log every X update steps.
+        eval_steps (`int`, *optional*):
+            Run an evaluation every X steps.
+        save_steps (`int`, *optional*, defaults to 500):
+            Save a checkpoint every X update steps.
         total_flos (`float`, *optional*, defaults to 0):
             The total number of floating operations done by the model since the beginning of training (stored as floats
             to avoid overflow).
@@ -77,6 +83,9 @@ class TrainerState:
     epoch: Optional[float] = None
     global_step: int = 0
     max_steps: int = 0
+    logging_steps: int = 500
+    eval_steps: int = 500
+    save_steps: int = 500
     num_train_epochs: int = 0
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
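
Two things worth calling out about these new fields: the defaults of 500 mirror the corresponding `TrainingArguments` defaults, and because `TrainerState` is what gets serialized next to each checkpoint (as `trainer_state.json`), the resolved intervals now travel with a saved run. A stripped-down stand-in (not the real class) showing that round trip:

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class MiniTrainerState:  # stripped-down stand-in for the real TrainerState
    epoch: Optional[float] = None
    global_step: int = 0
    max_steps: int = 0
    logging_steps: int = 500
    eval_steps: int = 500
    save_steps: int = 500


state = MiniTrainerState(global_step=400, max_steps=2000, eval_steps=100)

# TrainerState is serialized alongside each checkpoint, so the resolved
# intervals persist with a saved run rather than being re-derived from
# the (possibly ratio-valued) args on resume.
print(json.dumps(asdict(state), indent=2))
```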
@@ -421,13 +430,13 @@ class DefaultFlowCallback(TrainerCallback):
         # Log
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
-        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
             control.should_log = True

         # Evaluate
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and state.global_step % args.eval_steps == 0
+            and state.global_step % state.eval_steps == 0
             and args.eval_delay <= state.global_step
         ):
             control.should_evaluate = True
@@ -435,8 +444,8 @@ class DefaultFlowCallback(TrainerCallback):
         # Save
         if (
             args.save_strategy == IntervalStrategy.STEPS
-            and args.save_steps > 0
-            and state.global_step % args.save_steps == 0
+            and state.save_steps > 0
+            and state.global_step % state.save_steps == 0
         ):
             control.should_save = True
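
These two callback hunks are where the resolved values are finally consumed: each step boundary is a modulo check against `state.*_steps`. A compact simulation of the firing pattern, using a hypothetical stand-in state:

```python
from dataclasses import dataclass


@dataclass
class State:  # hypothetical stand-in for TrainerState
    logging_steps: int = 10
    eval_steps: int = 25
    save_steps: int = 50


state = State()

# Replay DefaultFlowCallback's modulo checks over the first 100 steps.
for step in range(1, 101):
    should_log = step % state.logging_steps == 0
    should_eval = step % state.eval_steps == 0
    should_save = state.save_steps > 0 and step % state.save_steps == 0
    if should_log or should_eval or should_save:
        flags = [name for name, on in
                 [("log", should_log), ("eval", should_eval), ("save", should_save)] if on]
        print(step, *flags)  # e.g. step 50 fires log, eval, and save at once
```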