Commit 75c1d866 authored by Yuxin Wu, committed by Facebook GitHub Bot

Add reduce_param_groups to D2

Summary: This utility function was added in D30272112 (https://github.com/facebookresearch/d2go/commit/737d099b0a8b0fb1f548435e73f95e1252442827) and is useful to all D2 users as well.

Differential Revision: D31833523

fbshipit-source-id: 0adfc612adb8b448fa7f3dbec1b1278c309554c5
parent 1967e62a
@@ -2,13 +2,13 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import itertools
 import logging
-from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 import torch
 from d2go.utils.qat_utils import iterate_module_named_parameters
 from detectron2.solver.build import (
     maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
+    reduce_param_groups,
 )
 from detectron2.utils.registry import Registry
@@ -58,49 +58,7 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
     )
     params += model.get_optimizer_param_groups(cfg)
-    # Reorganize the parameter groups and merge duplicated groups
-    # The number of parameter groups needs to be as small as possible in order
-    # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
-    # of using a parameter_group per single parameter, we reorganize the
-    # parameter groups and merge duplicated groups. This approach speeds
-    # up optimizer step significantly.
-    params = expand_optimizer_param_groups(params)
-    params = regroup_optimizer_param_groups(params)
-    return params
+    return reduce_param_groups(params)
-
-
-def expand_optimizer_param_groups(params: List[Dict[str, Any]]):
-    """Expand the optimizer parameter groups so that each group contains only
-    one parameter
-    """
-    ret = defaultdict(dict)
-    for item in params:
-        assert "params" in item
-        cur_params = {x: y for x, y in item.items() if x != "params"}
-        for param in item["params"]:
-            ret[param]["params"] = [param]
-            ret[param].update(cur_params)
-    ret = list(ret.values())
-    return ret
-
-
-def regroup_optimizer_param_groups(params: List[Dict[str, Any]]):
-    """Regroup the optimizer parameter groups using the optimizer parameters as key"""
-    groups = defaultdict(list)
-    for item in params:
-        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
-        groups[cur_params] += item["params"]
-    ret = []
-    for param_keys, param_values in groups.items():
-        cur = {kv[0]: kv[1] for kv in param_keys}
-        cur["params"] = param_values
-        ret.append(cur)
-    return ret
-
-
 def get_optimizer_param_groups_default(model: OptimizerModelsType):
...
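For context (not part of this commit): a minimal sketch of how the detectron2 reduce_param_groups helper is expected to behave here, using the sample groups from the tests removed below. The string parameters stand in for real nn.Parameter objects, and detectron2 must be installed.

# Illustration only: reduce_param_groups merges optimizer parameter groups that
# share the same hyperparameters, so the multi-tensor optimizer sees as few
# groups as possible.
from detectron2.solver.build import reduce_param_groups

groups = [
    {"params": ["p1", "p2", "p3", "p4"], "lr": 1.0, "weight_decay": 3.0},
    {"params": ["p2", "p3", "p5"], "lr": 2.0, "momentum": 2.0},
    {"params": ["p1"], "weight_decay": 4.0},
]

# Later entries override earlier ones per parameter, then parameters with
# identical hyperparameters are merged back together, e.g. "p2" and "p3" end up
# in one group with lr=2.0, weight_decay=3.0, momentum=2.0 (see the removed
# test cases below for the full expected grouping).
print(reduce_param_groups(groups))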
@@ -10,10 +10,6 @@ import torch
 from d2go.optimizer import (
     build_optimizer_mapper,
 )
-from d2go.optimizer.build import (
-    expand_optimizer_param_groups,
-    regroup_optimizer_param_groups,
-)
 from d2go.utils.testing import helper
@@ -94,67 +90,6 @@ def get_optimizer_cfg(
 class TestOptimizer(unittest.TestCase):
-    def test_expand_optimizer_param_groups(self):
-        groups = [
-            {
-                "params": ["p1", "p2", "p3", "p4"],
-                "lr": 1.0,
-                "weight_decay": 3.0,
-            },
-            {
-                "params": ["p2", "p3", "p5"],
-                "lr": 2.0,
-                "momentum": 2.0,
-            },
-            {
-                "params": ["p1"],
-                "weight_decay": 4.0,
-            },
-        ]
-        gt_groups = [
-            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
-            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
-            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
-            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
-            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
-        ]
-        out = expand_optimizer_param_groups(groups)
-        self.assertEqual(out, gt_groups)
-
-    def test_regroup_optimizer_param_groups(self):
-        expanded_groups = [
-            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
-            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
-            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
-            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
-            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
-        ]
-        gt_groups = [
-            {
-                "lr": 1.0,
-                "weight_decay": 4.0,
-                "params": ["p1"],
-            },
-            {
-                "lr": 2.0,
-                "weight_decay": 3.0,
-                "momentum": 2.0,
-                "params": ["p2", "p3"],
-            },
-            {
-                "lr": 1.0,
-                "weight_decay": 3.0,
-                "params": ["p4"],
-            },
-            {
-                "lr": 2.0,
-                "momentum": 2.0,
-                "params": ["p5"],
-            },
-        ]
-        out = regroup_optimizer_param_groups(expanded_groups)
-        self.assertEqual(out, gt_groups)
-
     def test_create_optimizer_default(self):
         class Model(torch.nn.Module):
             def __init__(self):
...