optimizer_factory.py 13.8 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7
8
# pyre-unsafe

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
9
import inspect
10
11
import logging
import os
12
13
14
from collections import defaultdict
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

import torch.optim

from accelerate import Accelerator

from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)

logger = logging.getLogger(__name__)


class OptimizerFactoryBase(ReplaceableBase):
    """
    Interface for replaceable components that build the optimizer and the
    learning rate scheduler for an Implicitron model.
    """

    def __call__(
        self,
        model: ImplicitronModelBase,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Build the optimizer and its learning rate scheduler for `model`.

        Subclasses must override this; the base implementation always raises.

        Args:
            model: The model whose parameters will be optimized; its weights
                may already have been loaded.
            **kwargs: Implementation-specific options.

        Returns:
            A pair `(optimizer, scheduler)`: the optimizer (which an
            implementation may restore from a checkpoint) and a learning rate
            scheduler, expected to be a subclass of
            torch.optim.lr_scheduler._LRScheduler.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    """
    A factory that initializes the optimizer and lr scheduler.

    Members:
        betas: Beta parameters for the Adam optimizer.
        breed: The type of optimizer to use. We currently support SGD, Adagrad
            and Adam.
        exponential_lr_step_size: With Exponential and LinearExponential
            policies only, lr = lr * gamma ** (epoch/step_size)
        gamma: Multiplicative factor of learning rate decay.
        lr: The value for the initial learning rate.
        lr_policy: The policy to use for learning rate. We currently support
            MultiStepLR, Exponential and LinearExponential policies (the name
            is matched case-insensitively).
        momentum: Momentum factor for the SGD optimizer (SGD only).
        multistep_lr_milestones: With MultiStepLR policy only: list of
            increasing epoch indices at which the learning rate is modified.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
        linear_exponential_lr_milestone: With LinearExponential policy only:
            the epoch at which the linear ramp ends and the exponential decay
            (as for the Exponential policy) starts.
        linear_exponential_start_gamma: With LinearExponential policy only:
            the lr multiplier at epoch 0; it grows linearly to 1 at
            `linear_exponential_lr_milestone`.
        foreach: Whether to use new "foreach" implementation of optimizer where
            available (e.g. requires PyTorch 1.12.0 for Adam)
        group_learning_rates: Parameters or modules can be assigned to parameter
            groups. This dictionary has names of those parameter groups as keys
            and learning rates as values. All parameter group names have to be
            defined in this dictionary. Parameters which do not have predefined
            parameter group are put into "default" parameter group which has
            `lr` as its learning rate.
    """

    betas: Tuple[float, ...] = (0.9, 0.999)
    breed: str = "Adam"
    exponential_lr_step_size: int = 250
    gamma: float = 0.1
    lr: float = 0.0005
    lr_policy: str = "MultiStepLR"
    momentum: float = 0.9
    multistep_lr_milestones: tuple = ()
    weight_decay: float = 0.0
    linear_exponential_lr_milestone: int = 200
    linear_exponential_start_gamma: float = 0.1
    foreach: Optional[bool] = True
    group_learning_rates: Dict[str, float] = field(default_factory=lambda: {})

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        last_epoch: int,
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.

        Args:
            last_epoch: If the model was loaded from checkpoint this will be the
                number of the last epoch that was saved.
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.
            resume: If True, attempt to load optimizer checkpoint from exp_dir.
                Failure to do so will return a newly initialized optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.
            kwargs: Unused; accepted for interface compatibility.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).

        Raises:
            ValueError: If `breed` or `lr_policy` is not supported, or if a
                parameter group has no learning rate in `group_learning_rates`.
            FileNotFoundError: If a requested checkpoint or optimizer state
                file is missing (see `_get_optimizer_state`).
        """
        # Get the parameters to optimize
        if hasattr(model, "_get_param_groups"):  # use the model function
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            p_groups = [
                {"params": params, "lr": self._get_group_learning_rate(group)}
                for group, params in self._get_param_groups(model).items()
            ]

        # Initialize the optimizer
        optimizer_kwargs: Dict[str, Any] = {
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if self.breed == "SGD":
            optimizer_class = torch.optim.SGD
            optimizer_kwargs["momentum"] = self.momentum
        elif self.breed == "Adagrad":
            optimizer_class = torch.optim.Adagrad
        elif self.breed == "Adam":
            optimizer_class = torch.optim.Adam
            optimizer_kwargs["betas"] = self.betas
        else:
            raise ValueError(f"No such solver type {self.breed}")

        # Only pass `foreach` when this optimizer's constructor accepts it
        # (older PyTorch versions do not).
        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
            optimizer_kwargs["foreach"] = self.foreach
        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
        logger.info("Solver type = %s", self.breed)

        # Load state from checkpoint
        optimizer_state = self._get_optimizer_state(
            exp_dir,
            accelerator,
            resume_epoch=resume_epoch,
            resume=resume,
        )
        if optimizer_state is not None:
            logger.info("Setting loaded optimizer state.")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler.
        # NOTE: `verbose` is intentionally not passed to LambdaLR: its default
        # is already "not verbose", and PyTorch >= 2.2 deprecates the argument
        # and warns whenever it is given explicitly.
        policy = self.lr_policy.casefold()
        if policy == "MultiStepLR".casefold():
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.multistep_lr_milestones,
                gamma=self.gamma,
            )
        elif policy == "Exponential".casefold():
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
            )
        elif policy == "LinearExponential".casefold():
            # linear learning rate progression between epochs 0 to
            # self.linear_exponential_lr_milestone, followed by exponential
            # lr decay for the rest of the epochs
            def _get_lr(epoch: int):
                m = self.linear_exponential_lr_milestone
                if epoch < m:
                    # w goes from 1 (epoch 0) towards 0 (epoch m), so the
                    # multiplier ramps from start_gamma up to 1.
                    w = (m - epoch) / m
                    gamma = w * self.linear_exponential_start_gamma + (1 - w)
                else:
                    epoch_rest = epoch - m
                    gamma = self.gamma ** (epoch_rest / self.exponential_lr_step_size)
                return gamma

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _get_lr)
        else:
            raise ValueError("no such lr policy %s" % self.lr_policy)

        # When loading from checkpoint, this will make sure that the
        # lr is correctly set even after returning.
        for _ in range(last_epoch):
            scheduler.step()

        optimizer.zero_grad()

        return optimizer, scheduler

    def _get_optimizer_state(
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.

        Args:
            exp_dir: Root experiment directory. If None, nothing is loaded.
            accelerator: An optional Accelerator instance. On processes that
                are not the local main process, it is used to remap tensors
                stored on "cuda:0" to this process's local device index.
            resume: If True, attempt to load the last checkpoint from `exp_dir`.
                If False, nothing is loaded.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.

        Returns:
            The loaded optimizer state dict, or None if nothing was loaded
            (in which case the caller keeps a freshly initialized optimizer).

        Raises:
            FileNotFoundError: If the checkpoint for `resume_epoch` does not
                exist, or a checkpoint was found but its optimizer state file
                does not exist.
        """
        if exp_dir is None or not resume:
            return None
        if resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(save_path):
                raise FileNotFoundError(
                    f"Cannot find optimizer from epoch {resume_epoch}."
                )
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        if save_path is None:
            return None

        logger.info("Found previous optimizer state %s -> resuming.", save_path)
        opt_path = model_io.get_optimizer_path(save_path)
        if not os.path.isfile(opt_path):
            raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")

        map_location = None
        if accelerator is not None and not accelerator.is_local_main_process:
            map_location = {
                "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
            }
        return torch.load(opt_path, map_location)

    def _get_param_groups(
        self, module: torch.nn.Module
    ) -> Dict[str, List[torch.nn.Parameter]]:
        """
        Recursively visits all the modules inside the `module` and sorts all the
        parameters in parameter groups.

        Uses the `param_groups` dictionary member of each visited module (if
        present), where keys are names of individual parameters or module
        members and values are the names of the parameter groups for those
        parameters or members. The "self" key is used to denote the parameter
        group at the module level. None of the keys, including "self", have to
        be defined. By default all parameters go to the "default" group, which
        uses the optimizer's learning rate. This can be overridden by setting
        the parameter group in the `param_groups` member of a specific module.
        The keys specify what parameters will be affected as follows:
            - "self": All the parameters of the module and its child modules
            - name of a parameter: A parameter with that name.
            - name of a module member: All the parameters of the module and its
                child modules.
                This is useful if members do not have `param_groups`, for
                example torch.nn.Linear.
            - <name of module member>.<something>: recursive. Same as if <something>
                was used in param_groups of that submodule/member.

        Only parameters with `requires_grad=True` are collected.

        Args:
            module: module from which to extract the parameters and their parameter
                groups
        Returns:
            dictionary with parameter groups as keys and lists of parameters as values
        """

        param_groups = defaultdict(list)

        def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
            """
            Visitor for module to assign its parameters to the relevant member of
            param_groups.

            Args:
                module: the module being visited in a depth-first search
                default_group: the param group to assign parameters to unless
                    otherwise overridden.
                mapping: known mappings of parameters to groups for this module,
                    destructively modified by this function.
            """
            # If key self is defined in param_groups then change the default param
            # group for all parameters and children in the module.
            if hasattr(module, "param_groups") and "self" in module.param_groups:
                default_group = module.param_groups["self"]

            # Collect all the parameters that are directly inside the `module`,
            # they will be in the default param group if they don't have
            # defined group. The module's own param_groups take precedence over
            # mappings inherited from the parent.
            if hasattr(module, "param_groups"):
                mapping.update(module.param_groups)

            for name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    group_name = mapping.get(name, default_group)
                    # Lazy %-args: this runs once per parameter, so avoid
                    # formatting the message when debug logging is disabled.
                    logger.debug("Assigning %s to param_group %s", name, group_name)
                    param_groups[group_name].append(param)

            # If children have defined default param group then use it else pass
            # own default.
            for child_name, child in module.named_children():
                # Strip the "<child_name>." prefix so nested keys are expressed
                # relative to the child module.
                mapping_to_add = {
                    name[len(child_name) + 1 :]: group
                    for name, group in mapping.items()
                    if name.startswith(child_name + ".")
                }
                traverse(child, mapping.get(child_name, default_group), mapping_to_add)

        traverse(module, "default", {})
        return param_groups

    def _get_group_learning_rate(self, group_name: str) -> float:
        """
        Wraps the `group_learning_rates` dictionary providing errors and returns
        `self.lr` for "default" group_name.

        Args:
            group_name: a string representing the name of the group
        Returns:
            learning rate for a specific group
        Raises:
            ValueError: if `group_name` is not "default" and has no entry in
                `group_learning_rates`.
        """
        if group_name == "default":
            return self.lr
        lr = self.group_learning_rates.get(group_name, None)
        if lr is None:
            raise ValueError(f"no learning rate given for group {group_name}")
        return lr