Unverified Commit 935330dd authored by OllieBroadhurst's avatar OllieBroadhurst Committed by GitHub
Browse files

Trainer evaluation delay (#16356)

* Initial commit

* Reversed signs, adjusted log entery.

* Check only when

* Cleanup checks

* Only trigger if we want to eval

* Run

* Move changes to callback
parent a220f160
......@@ -416,7 +416,7 @@ class DefaultFlowCallback(TrainerCallback):
control.should_log = True
# Evaluate
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0 and args.eval_delay > state.global_step:
control.should_evaluate = True
# Save
......@@ -439,7 +439,7 @@ class DefaultFlowCallback(TrainerCallback):
control.should_log = True
# Evaluate
if args.evaluation_strategy == IntervalStrategy.EPOCH:
if args.evaluation_strategy == IntervalStrategy.EPOCH and args.eval_delay > state.epoch:
control.should_evaluate = True
# Save
......
......@@ -138,6 +138,8 @@ class TrainingArguments:
Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster but
requires more memory).
eval_delay (`float`, *optional*):
Number of epochs or steps to wait for before the first evaluation can be performed, depending on the evaluation_strategy.
learning_rate (`float`, *optional*, defaults to 5e-5):
The initial learning rate for [`AdamW`] optimizer.
weight_decay (`float`, *optional*, defaults to 0):
......@@ -472,6 +474,11 @@ class TrainingArguments:
metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
)
eval_delay: Optional[float] = field(
default=0,
metadata={"help": "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the evaluation_strategy."},
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment