Commit 2dc3bc02 authored by Peizhao Zhang, committed by Facebook GitHub Bot

Refactor for get_optimizer_param_groups.

Summary:
Refactor for get_optimizer_param_groups.
* Split `get_default_optimizer_params()` into multiple functions:
  * `get_optimizer_param_groups_default()`
  * `get_optimizer_param_groups_lr()`
  * `get_optimizer_param_groups_weight_decay()`
* Regroup the parameters to create the minimal number of groups (a short sketch of this merge follows the list).
* Print all parameter groups when the optimizer is created.
    Param group 0: {amsgrad: False, betas: (0.9, 0.999), eps: 1e-08, lr: 10.0, params: 1, weight_decay: 1.0}
    Param group 1: {amsgrad: False, betas: (0.9, 0.999), eps: 1e-08, lr: 1.0, params: 1, weight_decay: 1.0}
    Param group 2: {amsgrad: False, betas: (0.9, 0.999), eps: 1e-08, lr: 1.0, params: 2, weight_decay: 0.0}
* Add some unit tests.
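
For illustration, here is a minimal sketch (not part of the commit) of the two-step merge, using the `expand_optimizer_param_groups` and `regroup_optimizer_param_groups` helpers introduced in the diff below; the parameter names ("w1", "w2") and hyperparameter values are made up:

    from d2go.optimizer.build import (
        expand_optimizer_param_groups,
        regroup_optimizer_param_groups,
    )

    # Groups that appear later override earlier ones for the parameters they share.
    groups = [
        {"params": ["w1", "w2"], "lr": 0.1, "weight_decay": 1e-4},
        {"params": ["w2"], "weight_decay": 0.0},
    ]
    # Expand to one group per parameter, then merge groups with identical settings.
    merged = regroup_optimizer_param_groups(expand_optimizer_param_groups(groups))
    # merged == [
    #     {"lr": 0.1, "weight_decay": 1e-4, "params": ["w1"]},
    #     {"lr": 0.1, "weight_decay": 0.0, "params": ["w2"]},
    # ]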

Reviewed By: zhanghang1989

Differential Revision: D31287783

fbshipit-source-id: e87df0ae0e67343bb2130db945d8faced44d7411
parent 46f16a5e
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import torch
from detectron2.solver.build import (
@@ -11,77 +13,180 @@ from detectron2.utils.registry import Registry

D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")
logger = logging.getLogger(__name__)

OptimizerModelsType = Union[torch.nn.Module, torch.nn.parallel.DistributedDataParallel]


def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
    """
    Get the optimizer parameter groups (with overrides) for the model:
    * Get all default parameters
    * Get parameter groups for lr
    * Get parameter groups for normalization, bias, and embedding
    * Get parameter groups from the model if it implements `get_optimizer_param_groups()`
    Parameters that appear later override parameters that appear earlier.
    """
    # get all parameters that require gradients
    params = get_optimizer_param_groups_default(model)

    # parameter groups for lr
    params += get_optimizer_param_groups_lr(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
    )

    # parameter groups for normalization, bias, and embedding
    params += get_optimizer_param_groups_weight_decay(
        model,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
    )

    # Reorganize the parameter groups and merge duplicated groups.
    # The number of parameter groups needs to be as small as possible in order
    # to efficiently use the PyTorch multi-tensor optimizer. Therefore, instead
    # of using a parameter group per single parameter, we reorganize the
    # parameter groups and merge duplicated groups. This approach speeds
    # up the optimizer step significantly.
    params = expand_optimizer_param_groups(params)
    params = regroup_optimizer_param_groups(params)

    return params
def expand_optimizer_param_groups(params: List[Dict[str, Any]]):
    """Expand the optimizer parameter groups so that each group contains only
    one parameter
    """
    ret = defaultdict(dict)
    for item in params:
        assert "params" in item
        cur_params = {x: y for x, y in item.items() if x != "params"}
        for param in item["params"]:
            ret[param]["params"] = [param]
            ret[param].update(cur_params)
    ret = list(ret.values())
    return ret


def regroup_optimizer_param_groups(params: List[Dict[str, Any]]):
    """Regroup the optimizer parameter groups using the optimizer parameters as key"""
    groups = defaultdict(list)
    for item in params:
        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
        groups[cur_params] += item["params"]

    ret = []
    for param_keys, param_values in groups.items():
        cur = {kv[0]: kv[1] for kv in param_keys}
        cur["params"] = param_values
        ret.append(cur)
    return ret
def iterate_module_named_parameters(
    model: OptimizerModelsType, check_requires_grad=True
):
    """Iterate over all parameters for the model"""
    memo = set()
    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if check_requires_grad and not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            yield module_name, module, module_param_name, value


def get_optimizer_param_groups_default(model: OptimizerModelsType):
    ret = [
        {
            "params": list(
                filter(
                    lambda x: x.requires_grad,
                    model.parameters(),
                )
            )
        }
    ]
    return ret
def get_optimizer_param_groups_lr(
    model: OptimizerModelsType,
    base_lr: float,
    bias_lr_factor: float = 1.0,
    lr_multipliers_overwrite: Optional[Dict[str, float]] = None,
):
    """
    Allow setting up the lr for modules.
    base_lr: lr for all modules
    bias_lr_factor: scale factor for the lr of bias terms
    lr_multipliers_overwrite (dict: str -> float):
        Apply a different lr multiplier to a set of parameters whose names
        contain certain keys. For example, if lr_multipliers_overwrite={'backbone': 0.1},
        the LR of the parameters whose names contain 'backbone' will be scaled to 0.1x.
        Set lr_multipliers_overwrite=None if no multipliers are required.
    """
    params: List[Dict[str, Any]] = []
    for (
        module_name,
        _module,
        module_param_name,
        value,
    ) in iterate_module_named_parameters(model):
        cur_lr = base_lr
        if module_param_name == "bias":
            cur_lr = base_lr * bias_lr_factor
        if lr_multipliers_overwrite is not None:
            for kname, mult in lr_multipliers_overwrite.items():
                if kname in module_name:
                    # apply the multiplier to the params containing kname, e.g. backbone
                    cur_lr = cur_lr * mult
        params += [
            {
                "params": [value],
                "lr": cur_lr,
            }
        ]
    return params
def get_optimizer_param_groups_weight_decay(
    model: OptimizerModelsType,
    weight_decay: Optional[float],
    weight_decay_norm: Optional[float] = None,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embed: Optional[float] = None,
):
    """
    Allow setting up weight decay for normalization, embedding and bias
    """
    if weight_decay_norm is None:
        weight_decay_norm = weight_decay
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    if weight_decay_embed is None:
        weight_decay_embed = weight_decay

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
@@ -90,49 +195,58 @@ def get_default_optimizer_params(
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    for (
        _module_name,
        module,
        module_param_name,
        value,
    ) in iterate_module_named_parameters(model):
        cur_wd = weight_decay
        if isinstance(module, norm_module_types):
            cur_wd = weight_decay_norm
        elif isinstance(module, torch.nn.Embedding):
            cur_wd = weight_decay_embed
        elif module_param_name == "bias":
            cur_wd = weight_decay_bias

        if cur_wd is not None:
            params += [
                {
                    "params": [value],
                    "weight_decay": cur_wd,
                }
            ]

    return params
def get_optimizer_param_groups_override(
    model: OptimizerModelsType,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Allow setting up overrides for parameter groups
    overrides (dict: str -> (dict: str -> float)):
        if not `None`, provides values for optimizer hyperparameters
        (LR, weight decay) for module parameters with a given name; e.g.
        {"embedding": {"lr": 0.01, "weight_decay": 0.1}} will set the LR and
        weight decay values for all module parameters named `embedding` (default: None)
    """
    params: List[Dict[str, Any]] = []
    if overrides is None:
        return params

    for (
        _module_name,
        _module,
        module_param_name,
        value,
    ) in iterate_module_named_parameters(model):
        schedule_params = {}
        if module_param_name in overrides:
            schedule_params.update(overrides[module_param_name])
        params += [{"params": [value], **schedule_params}]

    return params
@@ -170,16 +284,7 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
        params,
        cfg.SOLVER.BASE_LR,
@@ -193,16 +298,7 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim.AdamW)(
        params, cfg.SOLVER.BASE_LR
    )
@@ -216,17 +312,7 @@ def sgd_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    optimizer by end of H1'21. To benefit from the speedup, the number
    of parameter groups needs to be reduced using `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.SGD)(
        params,
        cfg.SOLVER.BASE_LR,
@@ -243,17 +329,7 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    optimizer by end of H1'21. To benefit from the speedup, the number
    of parameter groups needs to be reduced using `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.AdamW)(
        params, cfg.SOLVER.BASE_LR
    )
@@ -261,4 +337,23 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
def build_optimizer_mapper(cfg, model):
    name = cfg.SOLVER.OPTIMIZER
    optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)

    def _param_group_str(group):
        ret = {x: y if x != "params" else len(y) for x, y in group.items()}
        ret = sorted(ret.items())
        ret = [f"{x[0]}: {x[1]}" for x in ret]
        ret = "{" + ", ".join(ret) + "}"
        return ret

    def _param_groups_str(groups):
        ret = ""
        for idx, group in enumerate(groups):
            ret += f"Param group {idx}: {_param_group_str(group)}\n"
        return ret

    logger.info(
        f"optimizer parameter groups:\n{_param_groups_str(optimizer.param_groups)}"
    )

    return optimizer
@@ -7,7 +7,13 @@ import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.optimizer import (
    build_optimizer_mapper,
)
from d2go.optimizer.build import (
    expand_optimizer_param_groups,
    regroup_optimizer_param_groups,
)


class TestArch(torch.nn.Module):
@@ -57,7 +63,167 @@ def _test_each_optimizer(cfg):
    print("Correct prediction rate {0}.".format(n_correct / 200))
def _check_param_group(self, group, num_params=None, **kwargs):
    if num_params is not None:
        self.assertEqual(len(group["params"]), num_params)
    for key, val in kwargs.items():
        self.assertEqual(group[key], val)


def get_optimizer_cfg(
    lr,
    weight_decay=None,
    weight_decay_norm=None,
    weight_decay_bias=None,
    lr_mult=None,
):
    runner = default_runner.Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    if lr is not None:
        cfg.SOLVER.BASE_LR = lr
    if weight_decay is not None:
        cfg.SOLVER.WEIGHT_DECAY = weight_decay
    if weight_decay_norm is not None:
        cfg.SOLVER.WEIGHT_DECAY_NORM = weight_decay_norm
    if weight_decay_bias is not None:
        cfg.SOLVER.WEIGHT_DECAY_BIAS = weight_decay_bias
    if lr_mult is not None:
        cfg.SOLVER.LR_MULTIPLIER_OVERWRITE = [lr_mult]
    return cfg
class TestOptimizer(unittest.TestCase):
    def test_expand_optimizer_param_groups(self):
        groups = [
            {
                "params": ["p1", "p2", "p3", "p4"],
                "lr": 1.0,
                "weight_decay": 3.0,
            },
            {
                "params": ["p2", "p3", "p5"],
                "lr": 2.0,
                "momentum": 2.0,
            },
            {
                "params": ["p1"],
                "weight_decay": 4.0,
            },
        ]
        gt_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        out = expand_optimizer_param_groups(groups)
        self.assertEqual(out, gt_groups)

    def test_regroup_optimizer_param_groups(self):
        expanded_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        gt_groups = [
            {
                "lr": 1.0,
                "weight_decay": 4.0,
                "params": ["p1"],
            },
            {
                "lr": 2.0,
                "weight_decay": 3.0,
                "momentum": 2.0,
                "params": ["p2", "p3"],
            },
            {
                "lr": 1.0,
                "weight_decay": 3.0,
                "params": ["p4"],
            },
            {
                "lr": 2.0,
                "momentum": 2.0,
                "params": ["p5"],
            },
        ]
        out = regroup_optimizer_param_groups(expanded_groups)
        self.assertEqual(out, gt_groups)
    def test_create_optimizer_default(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0, weight_decay=1.0, weight_decay_norm=1.0, weight_decay_bias=1.0
        )
        optimizer = build_optimizer_mapper(cfg, model)
        self.assertEqual(len(optimizer.param_groups), 1)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=4, weight_decay=1.0, lr=1.0
        )
    def test_create_optimizer_lr(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 1)
                self.conv2 = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv2(self.conv1(x)))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0,
            lr_mult={"conv1": 3.0, "conv2": 3.0},
            weight_decay=2.0,
            weight_decay_norm=2.0,
        )
        optimizer = build_optimizer_mapper(cfg, model)

        self.assertEqual(len(optimizer.param_groups), 2)

        _check_param_group(self, optimizer.param_groups[0], num_params=4, lr=3.0)
        _check_param_group(self, optimizer.param_groups[1], num_params=2, lr=1.0)
    def test_create_optimizer_weight_decay_norm(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0, weight_decay=1.0, weight_decay_norm=2.0, weight_decay_bias=1.0
        )
        optimizer = build_optimizer_mapper(cfg, model)
        self.assertEqual(len(optimizer.param_groups), 2)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=2, lr=1.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[1], num_params=2, lr=1.0, weight_decay=2.0
        )
    def test_all_optimizers(self):
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()