"examples/community/vscode:/vscode.git/clone" did not exist on "27061426f7f67739e27abdbd92f9826e450c53d2"
Commit 2144cbae authored by Sze-qq's avatar Sze-qq Committed by Frank Lee
Browse files

[NFC] polish colossalai/nn/lr_scheduler/multistep.py code style (#1572)

parent e4bf7ae6
......@@ -22,7 +22,13 @@ class MultiStepLR(_MultiStepLR):
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
"""
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
def __init__(self,
optimizer,
total_steps: int,
milestones: List[int] = None,
gamma: float = 0.1,
last_epoch: int = -1,
**kwargs):
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
......@@ -41,12 +47,16 @@ class MultiStepWarmupLR(WarmupScheduler):
the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr.
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
gamma: float = 0.1, last_epoch: int = -1, **kwargs):
def __init__(self,
optimizer,
total_steps: int,
warmup_steps: int = 0,
milestones: List[int] = None,
gamma: float = 0.1,
last_epoch: int = -1,
**kwargs):
if len(milestones) == 0:
raise ValueError('milestones cannot be empty')
milestones = [
v - warmup_steps for v in milestones if v >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
gamma=gamma)
milestones = [v - warmup_steps for v in milestones if v >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment