Commit 87ce583c authored by Peizhao Zhang, committed by Facebook GitHub Bot

Support specifying customized parameter groups from the model.

Summary:
Support specifying customized parameter groups from the model.
* Allow a model to specify customized parameter groups by implementing a `model.get_optimizer_param_groups(cfg)` method (see the sketch below).
* Support models wrapped in DistributedDataParallel (DDP).
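
A minimal sketch of the new hook (illustrative only; `MyModel` is a hypothetical module, not part of this diff). Any group dicts returned by the hook are appended to the groups derived from `cfg.SOLVER` and then merged by the existing regrouping logic:

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

    # The new hook: return a list of torch.optim-style param group dicts.
    def get_optimizer_param_groups(self, cfg):
        # Give the conv weight its own learning rate; everything else keeps
        # the defaults built from cfg.SOLVER.
        return [{"params": [self.conv.weight], "lr": 10.0}]

As the tests below show, the model can be passed to `build_optimizer_mapper(cfg, model)` either directly or wrapped in DistributedDataParallel; in the DDP case the wrapper is unwrapped first so the hook defined on the underlying module is still found.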

Reviewed By: zhanghang1989

Differential Revision: D31289315

fbshipit-source-id: c91ba8014508e9fd5f172601b9c1c83c188338fd
parent 2dc3bc02
@@ -47,6 +47,15 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
        weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
    )

    # parameter groups from model function `model.get_optimizer_param_groups(opts)`
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    if hasattr(model, "get_optimizer_param_groups"):
        logger.info(
            "Getting optimizer parameter groups from model.get_optimizer_param_groups()"
        )
        params += model.get_optimizer_param_groups(cfg)

    # Reorganize the parameter groups and merge duplicated groups
    # The number of parameter groups needs to be as small as possible in order
    # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
...
@@ -14,6 +14,7 @@ from d2go.optimizer.build import (
    expand_optimizer_param_groups,
    regroup_optimizer_param_groups,
)
from d2go.utils.testing import helper


class TestArch(torch.nn.Module):
@@ -247,3 +248,75 @@ class TestOptimizer(unittest.TestCase):
            cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
            cfg.SOLVER.OPTIMIZER = optimizer_name
            _test_each_optimizer(cfg)
    def test_create_optimizer_custom(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

            def get_optimizer_param_groups(self, _opts):
                ret = [
                    {
                        "params": [self.conv.weight],
                        "lr": 10.0,
                    }
                ]
                return ret

        model = Model()
        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
        optimizer = build_optimizer_mapper(cfg, model)

        self.assertEqual(len(optimizer.param_groups), 3)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=1, lr=10.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[1], num_params=1, lr=1.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
        )

    @helper.enable_ddp_env
    def test_create_optimizer_custom_ddp(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

            def get_optimizer_param_groups(self, _opts):
                ret = [
                    {
                        "params": [self.conv.weight],
                        "lr": 10.0,
                    }
                ]
                return ret

        model = Model()
        model = torch.nn.parallel.DistributedDataParallel(model)
        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
        optimizer = build_optimizer_mapper(cfg, model)

        self.assertEqual(len(optimizer.param_groups), 3)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=1, lr=10.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[1], num_params=1, lr=1.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
        )