Unverified Commit 7730a79f authored by Yawei Li's avatar Yawei Li Committed by GitHub
Browse files

fix typo of annealing (#433)

parent 6ece0e5d
...@@ -313,7 +313,8 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -313,7 +313,8 @@ class BaseRunner(metaclass=ABCMeta):
# If the type of policy is all in lower case, e.g., 'cyclic', # If the type of policy is all in lower case, e.g., 'cyclic',
# then its first letter will be capitalized, e.g., to be 'Cyclic'. # then its first letter will be capitalized, e.g., to be 'Cyclic'.
# This is for the convenient usage of Lr updater. # This is for the convenient usage of Lr updater.
# Since this is not applicable for `CosineAnealingLrUpdater`, # Since this is not applicable for `
# CosineAnnealingLrUpdater`,
# the string will not be changed if it contains capital letters. # the string will not be changed if it contains capital letters.
if policy_type == policy_type.lower(): if policy_type == policy_type.lower():
policy_type = policy_type.title() policy_type = policy_type.title()
...@@ -333,7 +334,8 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -333,7 +334,8 @@ class BaseRunner(metaclass=ABCMeta):
# If the type of policy is all in lower case, e.g., 'cyclic', # If the type of policy is all in lower case, e.g., 'cyclic',
# then its first letter will be capitalized, e.g., to be 'Cyclic'. # then its first letter will be capitalized, e.g., to be 'Cyclic'.
# This is for the convenient usage of momentum updater. # This is for the convenient usage of momentum updater.
# Since this is not applicable for `CosineAnealingMomentumUpdater`, # Since this is not applicable for
# `CosineAnnealingMomentumUpdater`,
# the string will not be changed if it contains capital letters. # the string will not be changed if it contains capital letters.
if policy_type == policy_type.lower(): if policy_type == policy_type.lower():
policy_type = policy_type.title() policy_type = policy_type.title()
......
...@@ -224,13 +224,13 @@ class InvLrUpdaterHook(LrUpdaterHook): ...@@ -224,13 +224,13 @@ class InvLrUpdaterHook(LrUpdaterHook):
@HOOKS.register_module() @HOOKS.register_module()
class CosineAnealingLrUpdaterHook(LrUpdaterHook): class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs): def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
assert (min_lr is None) ^ (min_lr_ratio is None) assert (min_lr is None) ^ (min_lr_ratio is None)
self.min_lr = min_lr self.min_lr = min_lr
self.min_lr_ratio = min_lr_ratio self.min_lr_ratio = min_lr_ratio
super(CosineAnealingLrUpdaterHook, self).__init__(**kwargs) super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr): def get_lr(self, runner, base_lr):
if self.by_epoch: if self.by_epoch:
......
...@@ -103,13 +103,13 @@ class MomentumUpdaterHook(Hook): ...@@ -103,13 +103,13 @@ class MomentumUpdaterHook(Hook):
@HOOKS.register_module() @HOOKS.register_module()
class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook): class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs): def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
assert (min_momentum is None) ^ (min_momentum_ratio is None) assert (min_momentum is None) ^ (min_momentum_ratio is None)
self.min_momentum = min_momentum self.min_momentum = min_momentum
self.min_momentum_ratio = min_momentum_ratio self.min_momentum_ratio = min_momentum_ratio
super(CosineAnealingMomentumUpdaterHook, self).__init__(**kwargs) super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
def get_momentum(self, runner, base_momentum): def get_momentum(self, runner, base_momentum):
if self.by_epoch: if self.by_epoch:
......
...@@ -18,11 +18,11 @@ from torch.utils.data import DataLoader ...@@ -18,11 +18,11 @@ from torch.utils.data import DataLoader
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook, from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
PaviLoggerHook, WandbLoggerHook) PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnealingLrUpdaterHook, from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook) CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import ( from mmcv.runner.hooks.momentum_updater import (
CosineAnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook) CosineAnnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
def test_pavi_hook(): def test_pavi_hook():
...@@ -101,7 +101,7 @@ def test_cosine_runner_hook(): ...@@ -101,7 +101,7 @@ def test_cosine_runner_hook():
runner = _build_demo_runner() runner = _build_demo_runner()
# add momentum scheduler # add momentum scheduler
hook = CosineAnealingMomentumUpdaterHook( hook = CosineAnnealingMomentumUpdaterHook(
min_momentum_ratio=0.99 / 0.95, min_momentum_ratio=0.99 / 0.95,
by_epoch=False, by_epoch=False,
warmup_iters=2, warmup_iters=2,
...@@ -109,7 +109,7 @@ def test_cosine_runner_hook(): ...@@ -109,7 +109,7 @@ def test_cosine_runner_hook():
runner.register_hook(hook) runner.register_hook(hook)
# add momentum LR scheduler # add momentum LR scheduler
hook = CosineAnealingLrUpdaterHook( hook = CosineAnnealingLrUpdaterHook(
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9) by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
runner.register_hook(hook) runner.register_hook(hook)
runner.register_hook(IterTimerHook()) runner.register_hook(IterTimerHook())
......
...@@ -150,7 +150,7 @@ def test_build_lr_momentum_hook(): ...@@ -150,7 +150,7 @@ def test_build_lr_momentum_hook():
# test policy that is already title # test policy that is already title
lr_config = dict( lr_config = dict(
policy='CosineAnealing', policy='CosineAnnealing',
by_epoch=False, by_epoch=False,
min_lr_ratio=0, min_lr_ratio=0,
warmup_iters=2, warmup_iters=2,
...@@ -200,7 +200,7 @@ def test_build_lr_momentum_hook(): ...@@ -200,7 +200,7 @@ def test_build_lr_momentum_hook():
# test policy that is already title # test policy that is already title
mom_config = dict( mom_config = dict(
policy='CosineAnealing', policy='CosineAnnealing',
min_momentum_ratio=0.99 / 0.95, min_momentum_ratio=0.99 / 0.95,
by_epoch=False, by_epoch=False,
warmup_iters=2, warmup_iters=2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment