Commit b9aa4855 authored by Valentin Andrei, committed by Facebook GitHub Bot

Fix LR auto-scale for multi-tensor optimizers

Reviewed By: stephenyan1231, zhanghang1989

Differential Revision: D30903817

fbshipit-source-id: 578e6b02a1bd59b1bd841399fc60111d320ae9aa
parent 3fd2e635
@@ -19,7 +19,7 @@ def reduce_param_groups(param_groups: List[Dict[str, Any]]):
     # with the same lr and weight_decay in a single group. This approach speeds
     # up optimizer step significantly.
-    dict_new_groups: Dict[str, Dict[str, Any]] = {}
+    dict_new_groups: Dict[tuple, Dict[str, Any]] = {}
     for param_group in param_groups:
         # value is a list of parameters from the previous group
...
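For context, the comment in the hunk above explains why `reduce_param_groups` exists: multi-tensor optimizers update an entire parameter group with a few fused kernels, so merging every group that shares the same `lr` and `weight_decay` makes `optimizer.step()` significantly faster. The annotation fix reflects that the merge key is a tuple of hyperparameter values, not a string. A minimal sketch of the idea (the helper name and the two-field key are illustrative assumptions, not the actual d2go implementation):

```python
from typing import Any, Dict, List, Tuple

def merge_param_groups(param_groups: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Merge groups that share hyperparameters, keyed by a tuple of their
    # values -- the reason the annotation is Dict[tuple, ...] and not
    # Dict[str, ...].
    merged: Dict[Tuple[float, float], Dict[str, Any]] = {}
    for group in param_groups:
        key = (group["lr"], group["weight_decay"])
        if key not in merged:
            merged[key] = {"lr": key[0], "weight_decay": key[1], "params": []}
        merged[key]["params"].extend(group["params"])
    return list(merged.values())
```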
@@ -114,6 +114,8 @@ def default_scale_d2_configs(cfg, new_world_size):
     lr_scales = {
         "sgd": gpu_scale,
         "adamw": 1,
+        "sgd_mt": gpu_scale,
+        "adamw_mt": 1,
     }
     optim_name = cfg.SOLVER.OPTIMIZER.lower()
     lr_scale = lr_scales[optim_name] if optim_name in lr_scales else gpu_scale
...
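This hunk is the actual fix. `lr_scales` tells `default_scale_d2_configs` how to rescale the base learning rate when the world size changes, and the fall-through on the last line defaults any unlisted optimizer to linear scaling. Before the fix the multi-tensor names were unlisted, so `AdamW_MT` had its learning rate scaled linearly like SGD instead of being left unchanged like `adamw`. A standalone sketch of the rule (the helper name and world-size arguments are illustrative assumptions):

```python
def scale_base_lr(base_lr: float, optimizer_name: str, old_ws: int, new_ws: int) -> float:
    # SGD follows the linear scaling rule (lr grows with world size);
    # AdamW's lr is left untouched. Before this fix the "_mt" names missed
    # the table and fell through to the linear default.
    gpu_scale = new_ws / old_ws
    lr_scales = {"sgd": gpu_scale, "adamw": 1, "sgd_mt": gpu_scale, "adamw_mt": 1}
    return base_lr * lr_scales.get(optimizer_name.lower(), gpu_scale)

assert scale_base_lr(0.1, "SGD_MT", old_ws=8, new_ws=16) == 0.2
assert scale_base_lr(0.1, "AdamW_MT", old_ws=8, new_ws=16) == 0.1
```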
@@ -2,6 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

+import random
 import unittest

 import d2go.runner.default_runner as default_runner
@@ -12,30 +13,49 @@ from d2go.optimizer import build_optimizer_mapper
 class TestArch(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)
+        self.conv = torch.nn.Conv2d(3, 4, kernel_size=5, stride=1, padding=1)
         self.bn = torch.nn.BatchNorm2d(4)
         self.relu = torch.nn.ReLU(inplace=True)
         self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+        self.linear = torch.nn.Linear(4, 1)

     def forward(self, x):
         ret = self.conv(x)
         ret = self.bn(ret)
         ret = self.relu(ret)
         ret = self.avgpool(ret)
+        ret = torch.transpose(ret, 1, 3)
+        ret = self.linear(ret)
         return ret


 def _test_each_optimizer(cfg):
+    print("Solver: " + str(cfg.SOLVER.OPTIMIZER))
     model = TestArch()
+    criterion = torch.nn.BCEWithLogitsLoss()
     optimizer = build_optimizer_mapper(cfg, model)
     optimizer.zero_grad()
-    for _ in range(10):
-        x = torch.rand(1, 3, 24, 24)
-        y = model(x)
-        loss = y.mean()
+
+    random.seed(20210912)
+    for _ in range(2500):
+        target = torch.empty(1, 1, 1, 1).fill_(random.randint(0, 1))
+        x = torch.add(torch.rand(1, 3, 16, 16), 2 * target)
+        y_pred = model(x)
+        loss = criterion(y_pred, target)
         loss.backward()
         optimizer.step()
+
+    n_correct = 0
+    for _ in range(200):
+        target = torch.empty(1, 1, 1, 1).fill_(random.randint(0, 1))
+        x = torch.add(torch.rand(1, 3, 16, 16), 2 * target)
+        y_pred = torch.round(torch.sigmoid(model(x)))
+        if y_pred == target:
+            n_correct += 1
+    print("Correct prediction rate {0}.".format(n_correct / 200))


 class TestOptimizer(unittest.TestCase):
     def test_all_optimizers(self):
@@ -45,6 +65,7 @@ class TestOptimizer(unittest.TestCase):
         for optimizer_name in ["SGD", "AdamW", "SGD_MT", "AdamW_MT"]:
             for mult in multipliers:
+                cfg.SOLVER.BASE_LR = 0.01
                 cfg.SOLVER.OPTIMIZER = optimizer_name
                 cfg.SOLVER.MULTIPLIERS = mult
                 _test_each_optimizer(cfg)
@@ -54,6 +75,7 @@ class TestOptimizer(unittest.TestCase):
         cfg = runner.get_default_cfg()
         for optimizer_name in ["SGD", "AdamW", "SGD_MT", "AdamW_MT"]:
+            cfg.SOLVER.BASE_LR = 0.02
             cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.2
             cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
             cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
...
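The `_MT` suffix exercised by these tests refers to the multi-tensor optimizer variants named in the commit title: implementations that apply each update step to a whole list of parameters with a few fused kernels rather than looping over tensors in Python. In recent PyTorch releases the same behavior is exposed through the `foreach` flag on the stock optimizers; a minimal sketch assuming a PyTorch version with that flag (the d2go `*_MT` mappers themselves are registered elsewhere and are not shown in this diff):

```python
import torch

model = torch.nn.Linear(4, 1)
# foreach=True selects the multi-tensor code path: parameter updates are
# batched across the whole group instead of applied one tensor at a time.
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, foreach=True)

x, target = torch.rand(8, 4), torch.rand(8, 1)
loss = torch.nn.functional.mse_loss(model(x), target)
loss.backward()
opt.step()
opt.zero_grad()
```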