"vscode:/vscode.git/clone" did not exist on "0554e4d5c56c11f90312c83389b4791475caf1b5"
Unverified Commit a7da2996 authored by Zach Mueller, committed by GitHub

Fix issue with ratio evaluation steps and auto find batch size (#25436)

* Fully rebased solution

* 500
parent 2d6839ea
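The diff below is easiest to read with the failure mode in mind: with `auto_find_batch_size`, `_inner_training_loop` can re-run after an OOM with a smaller batch size, which changes `max_steps`. If the ratio-to-absolute conversion writes its result back into `args`, the retry sees an already-resolved absolute value instead of the original ratio. Resolving into the freshly created `TrainerState` on every pass avoids that. A minimal sketch of the arithmetic (the `resolve_steps` helper and the numbers are illustrative, not part of the commit):

```python
import math

def resolve_steps(value, max_steps):
    # Values in (0, 1) are read as a fraction of max_steps; integers pass through.
    if value is None:
        return None
    return math.ceil(max_steps * value) if value < 1 else value

args_eval_steps = 0.1                        # user-supplied ratio, never mutated
print(resolve_steps(args_eval_steps, 500))   # first pass:           50
print(resolve_steps(args_eval_steps, 1000))  # retry, smaller batch: 100
```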
@@ -746,7 +746,7 @@ class WandbCallback(TrainerCallback):
             # keep track of model topology and gradients, unsupported on TPU
             _watch_model = os.getenv("WANDB_WATCH", "false")
             if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
-                self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))
+                self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
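One reason this hunk switches the `watch` call from `args` to `state`: since `args.logging_steps` is no longer resolved in place, it may still hold the float ratio the user passed, and a ratio below 1 can never win the `max(100, ...)` comparison. A small sketch with assumed values (a ratio of 0.05 resolving to 250 against a hypothetical `max_steps` of 5000):

```python
args_logging_steps = 0.05    # unresolved ratio left on args
state_logging_steps = 250    # resolved interval stored on TrainerState

print(max(100, args_logging_steps))   # 100 -- the ratio is effectively ignored
print(max(100, state_logging_steps))  # 250 -- the resolved interval is honored
```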
@@ -1586,14 +1586,6 @@ class Trainer:
                     f" {args.max_steps}"
                 )

-        # Compute absolute values for logging, eval, and save if given as ratio
-        if args.logging_steps and args.logging_steps < 1:
-            args.logging_steps = math.ceil(max_steps * args.logging_steps)
-        if args.eval_steps and args.eval_steps < 1:
-            args.eval_steps = math.ceil(max_steps * args.eval_steps)
-        if args.save_steps and args.save_steps < 1:
-            args.save_steps = math.ceil(max_steps * args.save_steps)
-
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
                 # nn.DataParallel(model) replicates the model, creating new variables and module
@@ -1627,6 +1619,23 @@ class Trainer:
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None

+        # Compute absolute values for logging, eval, and save if given as ratio
+        if args.logging_steps is not None:
+            if args.logging_steps < 1:
+                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+            else:
+                self.state.logging_steps = args.logging_steps
+        if args.eval_steps is not None:
+            if args.eval_steps < 1:
+                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+            else:
+                self.state.eval_steps = args.eval_steps
+        if args.save_steps is not None:
+            if args.save_steps < 1:
+                self.state.save_steps = math.ceil(max_steps * args.save_steps)
+            else:
+                self.state.save_steps = args.save_steps
+
         # Activate gradient checkpointing if needed
         if args.gradient_checkpointing:
             self.model.gradient_checkpointing_enable()
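A side effect of the `math.ceil` in the resolution above is worth noting: any nonzero ratio resolves to at least one step, so the modulo checks in `DefaultFlowCallback` never see a zero interval for ratio inputs. Illustrative values:

```python
import math

max_steps = 500
for ratio in (0.001, 0.1, 0.999):
    print(ratio, "->", math.ceil(max_steps * ratio))
# 0.001 -> 1, 0.1 -> 50, 0.999 -> 500
```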
@@ -53,6 +53,12 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (`int`, *optional*, defaults to 0):
             The number of update steps to do during the current training.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Log every X update steps.
+        eval_steps (`int`, *optional*, defaults to 500):
+            Run an evaluation every X update steps.
+        save_steps (`int`, *optional*, defaults to 500):
+            Save a checkpoint every X update steps.
         total_flos (`float`, *optional*, defaults to 0):
             The total number of floating operations done by the model since the beginning of training (stored as floats
             to avoid overflow).
@@ -77,6 +83,9 @@ class TrainerState:
     epoch: Optional[float] = None
     global_step: int = 0
     max_steps: int = 0
+    logging_steps: int = 500
+    eval_steps: int = 500
+    save_steps: int = 500
     num_train_epochs: int = 0
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
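Since `TrainerState` is a plain dataclass, the new fields are visible directly on a fresh instance; a quick check, assuming a `transformers` build that includes this commit:

```python
from transformers import TrainerState

# Defaults match the dataclass fields added above; the Trainer overwrites
# them with resolved values at the start of each inner training loop.
state = TrainerState()
print(state.logging_steps, state.eval_steps, state.save_steps)  # 500 500 500
```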
@@ -421,13 +430,13 @@ class DefaultFlowCallback(TrainerCallback):
         # Log
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
-        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
             control.should_log = True

         # Evaluate
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and state.global_step % args.eval_steps == 0
+            and state.global_step % state.eval_steps == 0
             and args.eval_delay <= state.global_step
         ):
             control.should_evaluate = True
@@ -435,8 +444,8 @@ class DefaultFlowCallback(TrainerCallback):
         # Save
         if (
             args.save_strategy == IntervalStrategy.STEPS
-            and args.save_steps > 0
-            and state.global_step % args.save_steps == 0
+            and state.save_steps > 0
+            and state.global_step % state.save_steps == 0
         ):
             control.should_save = True
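To see the rewritten checks fire, here is a worked trace with assumed resolved intervals (`logging_steps=50`, `save_steps=100`) and both strategies set to `IntervalStrategy.STEPS`:

```python
logging_steps, save_steps = 50, 100  # assumed values on TrainerState

for global_step in (49, 50, 100, 150, 200):
    should_log = global_step % logging_steps == 0
    should_save = save_steps > 0 and global_step % save_steps == 0
    print(global_step, should_log, should_save)
# 50 and 150 trigger a log; 100 and 200 trigger both a log and a save.
```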