# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict
from typing import Any, Dict, List

import oneflow as flow

from libai.config import instantiate
from libai.layers import LayerNorm

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/solver/build.py
# --------------------------------------------------------


def build_optimizer(cfg, model):
    """
    Build an optimizer from config.

    The given ``model`` is attached to ``cfg.params.model`` so that the
    parameter-group function configured there (e.g. ``get_default_optimizer_params``)
    receives it when the config is instantiated.
    """

    cfg.params.model = model
    optim = instantiate(cfg)
    return optim
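
# A minimal usage sketch (illustration only). It assumes a LazyCall-style config,
# with ``LazyCall`` importable from ``libai.config`` like ``instantiate`` above;
# the exact config layout in a given project may differ:
#
#   from libai.config import LazyCall
#
#   optim_cfg = LazyCall(flow.optim.AdamW)(
#       params=LazyCall(get_default_optimizer_params)(
#           # ``params.model`` is filled in by ``build_optimizer``
#           weight_decay_norm=0.0,
#           weight_decay_bias=0.0,
#       ),
#       lr=1e-4,
#       weight_decay=0.01,
#   )
#   optimizer = build_optimizer(optim_cfg, model)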


def get_default_optimizer_params(
    model,
    base_lr=None,
    weight_decay=None,
    weight_decay_norm=None,
    weight_decay_bias=None,
    clip_grad_max_norm=None,
    clip_grad_norm_type=None,
    overrides=None,
):
    """
    Get default param list for optimizer, with suport for a few types of overrides.
    If no overrides are needed, it is equivalent to `model.parameters()`.

    Arguments:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        weight_decay_bias: override weight decay for bias parameters
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common transformer models, ``weight_decay_norm`` and ``weight_decay_bias``
    are usually set to 0.

    Example:
    ::

        flow.optim.AdamW(
            get_default_optimizer_params(model, weight_decay_norm=0, weight_decay_bias=0),
            lr=0.01,
            weight_decay=1e-4
        )
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    if clip_grad_max_norm is not None and clip_grad_norm_type is not None:
        defaults["clip_grad_max_norm"] = clip_grad_max_norm
        defaults["clip_grad_norm_type"] = clip_grad_norm_type
    bias_overrides = {}
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    norm_module_types = (
        LayerNorm,
        flow.nn.BatchNorm1d,
        flow.nn.BatchNorm2d,
        flow.nn.BatchNorm3d,
        flow.nn.GroupNorm,
        flow.nn.InstanceNorm1d,
        flow.nn.InstanceNorm2d,
        flow.nn.InstanceNorm3d,
        flow.nn.FusedBatchNorm1d,
        flow.nn.FusedBatchNorm2d,
        flow.nn.FusedBatchNorm3d,
    )
    params = []
    memo = set()
    for module in model.modules():
        for model_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(model_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)
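
# Illustration only: ``overrides`` matches the parameter's *local* name inside its
# module (as yielded by ``named_parameters(recurse=False)``). For example, with
# hypothetical values:
#
#   params = get_default_optimizer_params(
#       model,
#       weight_decay_norm=0.0,
#       overrides={"bias": {"lr": 2e-4}},
#   )
#
# gives every parameter named "bias" a learning rate of 2e-4, while parameters of
# normalization layers get weight decay 0.0; everything else keeps the defaults.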


def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Transform parameter groups into a per-parameter structure.
    Later items in `params` overwrite hyperparameters set by earlier items
    for the same parameter.
    """
    ret = defaultdict(dict)
    for item in params:
        assert "params" in item
        cur_params = {x: y for x, y in item.items() if x != "params"}
        for param in item["params"]:
            ret[param].update({"params": [param], **cur_params})
    return list(ret.values())
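
# Illustration only (hypothetical parameters p1/p2): given
#   [{"params": [p1, p2], "lr": 0.1}, {"params": [p2], "weight_decay": 0.0}]
# _expand_param_groups returns one group per parameter, with later items
# overriding hyperparameters from earlier ones:
#   [{"params": [p1], "lr": 0.1},
#    {"params": [p2], "lr": 0.1, "weight_decay": 0.0}]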


def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorganize the parameter groups and merge duplicated groups.
    The number of parameter groups needs to be as small as possible in order
    to use the OneFlow multi-tensor optimizer efficiently. Therefore, instead
    of using one parameter group per parameter, we reorganize the parameter
    groups and merge groups that share the same hyperparameters. This approach
    speeds up the multi-tensor optimizer significantly.
    """
    params = _expand_param_groups(params)
    groups = defaultdict(list)  # re-group all parameter groups by their hyperparams
    for item in params:
        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
        groups[cur_params].extend(item["params"])
    ret = []
    for param_keys, param_values in groups.items():
        cur = {kv[0]: kv[1] for kv in param_keys}
        cur["params"] = param_values
        ret.append(cur)
    return ret
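
# Illustration only (hypothetical parameters p1..p3): per-parameter groups that
# share identical hyperparameters are merged into a single group, e.g.
#   [{"params": [p1], "lr": 0.1},
#    {"params": [p2], "lr": 0.1},
#    {"params": [p3], "lr": 0.1, "weight_decay": 0.0}]
# becomes
#   [{"lr": 0.1, "params": [p1, p2]},
#    {"lr": 0.1, "weight_decay": 0.0, "params": [p3]}]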