Commit 87ce583c authored by Peizhao Zhang, committed by Facebook GitHub Bot

Supported specifying customized parameter groups from the model.

Summary:
Supported specifying customized parameter groups from the model.
* Allow the model to specify customized parameter groups by implementing a method `model.get_optimizer_param_groups(cfg)`; these groups are merged with the default ones built from the config (see the sketch below).
* Support models wrapped in DDP (`torch.nn.parallel.DistributedDataParallel`) by unwrapping to `model.module` before looking up the hook.
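For illustration, a minimal self-contained sketch of the pattern using plain `torch` only. `MyModel`, `build_optimizer`, and the dict-based `cfg` are hypothetical stand-ins; the real builder in `d2go.optimizer.build` additionally expands and regroups the parameter groups (see the diff below).

```python
import torch


class MyModel(torch.nn.Module):
    """Toy model opting in to the new hook (names here are illustrative)."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.head = torch.nn.Linear(3, 10)

    def forward(self, x):
        return self.head(self.conv(x).mean(dim=(2, 3)))

    # The hook added by this commit: return extra parameter groups that
    # override the defaults for the listed parameters.
    def get_optimizer_param_groups(self, cfg):
        return [
            {"params": [self.head.weight, self.head.bias], "lr": 10.0 * cfg["base_lr"]}
        ]


def build_optimizer(model, cfg):
    # Unwrap DDP first, mirroring the change in d2go/optimizer/build.py.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    custom = (
        model.get_optimizer_param_groups(cfg)
        if hasattr(model, "get_optimizer_param_groups")
        else []
    )
    # Everything not claimed by a custom group falls into the default group.
    claimed = {id(p) for g in custom for p in g["params"]}
    default = [{"params": [p for p in model.parameters() if id(p) not in claimed]}]
    return torch.optim.SGD(default + custom, lr=cfg["base_lr"])


cfg = {"base_lr": 0.1}  # stand-in for the real CfgNode
opt = build_optimizer(MyModel(), cfg)
print([g["lr"] for g in opt.param_groups])  # [0.1, 1.0]
```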

Reviewed By: zhanghang1989

Differential Revision: D31289315

fbshipit-source-id: c91ba8014508e9fd5f172601b9c1c83c188338fd
parent 2dc3bc02
@@ -47,6 +47,15 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
         weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
     )
 
+    # parameter groups from model function `model.get_optimizer_param_groups(cfg)`
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model = model.module
+    if hasattr(model, "get_optimizer_param_groups"):
+        logger.info(
+            "Getting optimizer parameter groups from model.get_optimizer_param_groups()"
+        )
+        params += model.get_optimizer_param_groups(cfg)
+
     # Reorganize the parameter groups and merge duplicated groups
     # The number of parameter groups needs to be as small as possible in order
     # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
...
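The `model.module` unwrap above is needed because `DistributedDataParallel` does not forward arbitrary attributes or methods of the wrapped module, so the hook must be looked up on the inner module. A quick single-process check (gloo backend; the environment setup here is illustrative, not how d2go initializes distributed training):

```python
import os

import torch
import torch.distributed as dist

# Single-process process group, just enough to construct a DDP wrapper.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

    def get_optimizer_param_groups(self, cfg):
        return [{"params": [self.linear.weight], "lr": 10.0}]


ddp = torch.nn.parallel.DistributedDataParallel(Model())
print(hasattr(ddp, "get_optimizer_param_groups"))         # False: hidden by the wrapper
print(hasattr(ddp.module, "get_optimizer_param_groups"))  # True: unwrap first

dist.destroy_process_group()
```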
@@ -14,6 +14,7 @@ from d2go.optimizer.build import (
     expand_optimizer_param_groups,
     regroup_optimizer_param_groups,
 )
+from d2go.utils.testing import helper
 
 
 class TestArch(torch.nn.Module):
...
@@ -247,3 +248,75 @@ class TestOptimizer(unittest.TestCase):
             cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
             cfg.SOLVER.OPTIMIZER = optimizer_name
             _test_each_optimizer(cfg)
+
+    def test_create_optimizer_custom(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 1)
+                self.bn = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                return self.bn(self.conv(x))
+
+            def get_optimizer_param_groups(self, _opts):
+                ret = [
+                    {
+                        "params": [self.conv.weight],
+                        "lr": 10.0,
+                    }
+                ]
+                return ret
+
+        model = Model()
+        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
+        optimizer = build_optimizer_mapper(cfg, model)
+        self.assertEqual(len(optimizer.param_groups), 3)
+        _check_param_group(
+            self, optimizer.param_groups[0], num_params=1, lr=10.0, weight_decay=1.0
+        )
+        _check_param_group(
+            self, optimizer.param_groups[1], num_params=1, lr=1.0, weight_decay=1.0
+        )
+        _check_param_group(
+            self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
+        )
+
+    @helper.enable_ddp_env
+    def test_create_optimizer_custom_ddp(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 1)
+                self.bn = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                return self.bn(self.conv(x))
+
+            def get_optimizer_param_groups(self, _opts):
+                ret = [
+                    {
+                        "params": [self.conv.weight],
+                        "lr": 10.0,
+                    }
+                ]
+                return ret
+
+        model = Model()
+        model = torch.nn.parallel.DistributedDataParallel(model)
+        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
+        optimizer = build_optimizer_mapper(cfg, model)
+        self.assertEqual(len(optimizer.param_groups), 3)
+        _check_param_group(
+            self, optimizer.param_groups[0], num_params=1, lr=10.0, weight_decay=1.0
+        )
+        _check_param_group(
+            self, optimizer.param_groups[1], num_params=1, lr=1.0, weight_decay=1.0
+        )
+        _check_param_group(
+            self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
+        )
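In both tests the model's four parameters land in three groups: `conv.weight` comes from the custom hook (so its `lr` is overridden to 10.0), `conv.bias` keeps the default `lr=1.0` and `weight_decay=1.0`, and the two BatchNorm parameters share a group with `weight_decay_norm=0.0`. The DDP variant asserts the same grouping, confirming the hook is found on the unwrapped `model.module`.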