"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "873d9bb3ccc5758495201d739ceecac3a7c6e752"
Unverified Commit 037bdf82 authored by Masatoshi TSUCHIYA's avatar Masatoshi TSUCHIYA Committed by GitHub
Browse files

Refer to warmup_ratio when setting warmup_num_steps. (#12818)

* Refer to warmup_ratio when setting warmup_num_steps.

* Add a method to the TrainingArguments class to get the number of warmup steps.

* Fix.

* Fix.
parent 15d19ecf
......@@ -204,7 +204,6 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg
self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
self.fill_match("scheduler.params.warmup_num_steps", args.warmup_steps, "warmup_steps")
# total_num_steps - will get set in trainer_config_finalize
# fp16
......@@ -245,6 +244,7 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
# scheduler
self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
if len(self.mismatches) > 0:
mismatches = "\n".join(self.mismatches)
......
......@@ -822,16 +822,10 @@ class Trainer:
num_training_steps (int): The number of training steps to do.
"""
if self.lr_scheduler is None:
warmup_steps = (
self.args.warmup_steps
if self.args.warmup_steps > 0
else math.ceil(num_training_steps * self.args.warmup_ratio)
)
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
self.optimizer,
num_warmup_steps=warmup_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
......
......@@ -14,6 +14,7 @@
import contextlib
import json
import math
import os
import warnings
from dataclasses import asdict, dataclass, field
......@@ -1065,6 +1066,17 @@ class TrainingArguments:
else:
yield
def get_warmup_steps(self, num_training_steps: int):
    """
    Return the number of steps to use for the linear warmup.

    A strictly positive `self.warmup_steps` is returned as-is. Otherwise the count is computed as
    `num_training_steps * self.warmup_ratio`, rounded up to the next integer.
    """
    # An explicitly configured step count takes precedence over the ratio.
    if self.warmup_steps > 0:
        return self.warmup_steps
    return math.ceil(num_training_steps * self.warmup_ratio)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment