"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0afa5071bd84e44301750fdc594e33db102cf374"
Unverified Commit 037bdf82 authored by Masatoshi TSUCHIYA's avatar Masatoshi TSUCHIYA Committed by GitHub
Browse files

Refer to warmup_ratio when setting warmup_num_steps. (#12818)

* Refer to warmup_ratio when setting warmup_num_steps.

* Add a method to get the number of warmup steps to the TrainingArguments class.

* Fix.

* Fix.
parent 15d19ecf
...@@ -204,7 +204,6 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): ...@@ -204,7 +204,6 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg
self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
self.fill_match("scheduler.params.warmup_num_steps", args.warmup_steps, "warmup_steps")
# total_num_steps - will get set in trainer_config_finalize # total_num_steps - will get set in trainer_config_finalize
# fp16 # fp16
...@@ -245,6 +244,7 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): ...@@ -245,6 +244,7 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
# scheduler # scheduler
self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)") self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
if len(self.mismatches) > 0: if len(self.mismatches) > 0:
mismatches = "\n".join(self.mismatches) mismatches = "\n".join(self.mismatches)
......
...@@ -822,16 +822,10 @@ class Trainer: ...@@ -822,16 +822,10 @@ class Trainer:
num_training_steps (int): The number of training steps to do. num_training_steps (int): The number of training steps to do.
""" """
if self.lr_scheduler is None: if self.lr_scheduler is None:
warmup_steps = (
self.args.warmup_steps
if self.args.warmup_steps > 0
else math.ceil(num_training_steps * self.args.warmup_ratio)
)
self.lr_scheduler = get_scheduler( self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type, self.args.lr_scheduler_type,
self.optimizer, self.optimizer,
num_warmup_steps=warmup_steps, num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
) )
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import contextlib import contextlib
import json import json
import math
import os import os
import warnings import warnings
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
...@@ -1065,6 +1066,17 @@ class TrainingArguments: ...@@ -1065,6 +1066,17 @@ class TrainingArguments:
else: else:
yield yield
def get_warmup_steps(self, num_training_steps: int) -> int:
    """
    Get number of steps used for a linear warmup.

    An explicit ``warmup_steps`` greater than zero takes precedence. Otherwise the count is derived from
    ``warmup_ratio`` as a fraction of ``num_training_steps``, rounded up to the next whole step.

    Args:
        num_training_steps (int): The number of training steps to do.

    Returns:
        int: The number of warmup steps.
    """
    # A positive, explicitly configured step count always wins over the ratio.
    if self.warmup_steps > 0:
        return self.warmup_steps
    # Round up so that any non-zero ratio yields at least one warmup step.
    return math.ceil(num_training_steps * self.warmup_ratio)
def to_dict(self): def to_dict(self):
""" """
Serializes this instance while replace `Enum` by their values (for JSON serialization support). Serializes this instance while replace `Enum` by their values (for JSON serialization support).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment