Commit 75c1d866 authored by Yuxin Wu, committed by Facebook GitHub Bot

Add reduce_param_groups to D2

Summary: This utility function was added in D30272112 (https://github.com/facebookresearch/d2go/commit/737d099b0a8b0fb1f548435e73f95e1252442827) and is useful to all D2 (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8) users as well.

Differential Revision: D31833523

fbshipit-source-id: 0adfc612adb8b448fa7f3dbec1b1278c309554c5
parent 1967e62a
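
For context, detectron2's reduce_param_groups (imported in the diff below) collapses optimizer parameter groups that share the same options into a single group, which is what the two removed d2go helpers did in two steps. A minimal sketch of that merging behavior, using string placeholders instead of real parameters (the removed unit tests below use the same convention); the exact output ordering is an implementation detail:

import pprint

from detectron2.solver.build import reduce_param_groups

# Three groups; the first two share identical options and should be merged.
groups = [
    {"params": ["p1", "p2"], "lr": 1.0},
    {"params": ["p3"], "lr": 1.0},  # same options as the first group
    {"params": ["p4"], "lr": 0.1},  # different lr, stays separate
]

merged = reduce_param_groups(groups)
pprint.pprint(merged)
# Expected result (roughly):
#   [{'lr': 1.0, 'params': ['p1', 'p2', 'p3']},
#    {'lr': 0.1, 'params': ['p4']}]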
@@ -2,13 +2,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
import torch
from d2go.utils.qat_utils import iterate_module_named_parameters
from detectron2.solver.build import (
    maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
    reduce_param_groups,
)
from detectron2.utils.registry import Registry
@@ -58,49 +58,7 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
    )
    params += model.get_optimizer_param_groups(cfg)
    # Reorganize the parameter groups and merge duplicated groups
    # The number of parameter groups needs to be as small as possible in order
    # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
    # of using a parameter_group per single parameter, we reorganize the
    # parameter groups and merge duplicated groups. This approach speeds
    # up optimizer step significantly.
    params = expand_optimizer_param_groups(params)
    params = regroup_optimizer_param_groups(params)
    return params

def expand_optimizer_param_groups(params: List[Dict[str, Any]]):
    """Expand the optimizer parameter groups so that each group contains only
    one parameter
    """
    ret = defaultdict(dict)
    for item in params:
        assert "params" in item
        cur_params = {x: y for x, y in item.items() if x != "params"}
        for param in item["params"]:
            ret[param]["params"] = [param]
            ret[param].update(cur_params)
    ret = list(ret.values())
    return ret

def regroup_optimizer_param_groups(params: List[Dict[str, Any]]):
    """Regroup the optimizer parameter groups using the optimizer parameters as key"""
    groups = defaultdict(list)
    for item in params:
        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
        groups[cur_params] += item["params"]

    ret = []
    for param_keys, param_values in groups.items():
        cur = {kv[0]: kv[1] for kv in param_keys}
        cur["params"] = param_values
        ret.append(cur)
    return ret

    return reduce_param_groups(params)
def get_optimizer_param_groups_default(model: OptimizerModelsType):
......
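
The comment block in the hunk above states the motivation: PyTorch's multi-tensor optimizer implementations are most effective when the number of parameter groups is small, so per-parameter groups are merged whenever their options coincide, and reduce_param_groups now performs that reduction in one call. A small illustrative sketch of the effect with a toy model (not part of this diff; the model and hyperparameters are made up):

import torch
from detectron2.solver.build import reduce_param_groups

# Toy model whose parameters all end up with the same optimizer options.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# One group per parameter, as per-parameter overrides naturally produce.
per_param = [
    {"params": [p], "lr": 0.1, "weight_decay": 1e-4} for p in model.parameters()
]

merged = reduce_param_groups(per_param)
print(len(per_param), "->", len(merged))  # 4 -> 1 for this toy model

# The merged groups can be handed to any torch optimizer; fewer groups means
# larger tensor lists for the multi-tensor kernels and a faster optimizer step.
opt = torch.optim.SGD(merged)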
@@ -10,10 +10,6 @@ import torch
from d2go.optimizer import (
    build_optimizer_mapper,
)
from d2go.optimizer.build import (
    expand_optimizer_param_groups,
    regroup_optimizer_param_groups,
)
from d2go.utils.testing import helper
@@ -94,67 +90,6 @@ def get_optimizer_cfg(
class TestOptimizer(unittest.TestCase):
    def test_expand_optimizer_param_groups(self):
        groups = [
            {
                "params": ["p1", "p2", "p3", "p4"],
                "lr": 1.0,
                "weight_decay": 3.0,
            },
            {
                "params": ["p2", "p3", "p5"],
                "lr": 2.0,
                "momentum": 2.0,
            },
            {
                "params": ["p1"],
                "weight_decay": 4.0,
            },
        ]
        gt_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        out = expand_optimizer_param_groups(groups)
        self.assertEqual(out, gt_groups)

    def test_regroup_optimizer_param_groups(self):
        expanded_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        gt_groups = [
            {
                "lr": 1.0,
                "weight_decay": 4.0,
                "params": ["p1"],
            },
            {
                "lr": 2.0,
                "weight_decay": 3.0,
                "momentum": 2.0,
                "params": ["p2", "p3"],
            },
            {
                "lr": 1.0,
                "weight_decay": 3.0,
                "params": ["p4"],
            },
            {
                "lr": 2.0,
                "momentum": 2.0,
                "params": ["p5"],
            },
        ]
        out = regroup_optimizer_param_groups(expanded_groups)
        self.assertEqual(out, gt_groups)

    def test_create_optimizer_default(self):
        class Model(torch.nn.Module):
            def __init__(self):
......
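
Since the two d2go-specific helpers are removed, their unit tests go away with them and the merging behavior becomes detectron2's responsibility. Purely as an illustration (not part of this commit), a replacement-style test could exercise reduce_param_groups directly with the same inputs and expected groups as the two removed tests above; it assumes the merged groups preserve the order in which option combinations are first encountered, as the removed implementation did:

import unittest

from detectron2.solver.build import reduce_param_groups


class TestReduceParamGroups(unittest.TestCase):
    # Hypothetical test shown for illustration only; it folds the inputs and
    # expected outputs of the two removed tests into a single end-to-end check.
    def test_reduce_param_groups(self):
        groups = [
            {"params": ["p1", "p2", "p3", "p4"], "lr": 1.0, "weight_decay": 3.0},
            {"params": ["p2", "p3", "p5"], "lr": 2.0, "momentum": 2.0},
            {"params": ["p1"], "weight_decay": 4.0},
        ]
        gt_groups = [
            {"lr": 1.0, "weight_decay": 4.0, "params": ["p1"]},
            {"lr": 2.0, "weight_decay": 3.0, "momentum": 2.0, "params": ["p2", "p3"]},
            {"lr": 1.0, "weight_decay": 3.0, "params": ["p4"]},
            {"lr": 2.0, "momentum": 2.0, "params": ["p5"]},
        ]
        out = reduce_param_groups(groups)
        self.assertEqual(out, gt_groups)


if __name__ == "__main__":
    unittest.main()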