# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Optional

import torch.optim

from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats

logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):

    resume: bool = True  # resume from the last checkpoint

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Initialize the model (possibly from a previously saved state).

        Returns: An instance of ImplicitronModelBase.
        """
        raise NotImplementedError()

    def load_stats(self, **kwargs) -> Stats:
        """
        Initialize or load a Stats object.
        """
        raise NotImplementedError()


@registry.register
class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
    """
    A factory class that initializes an implicit rendering model.

    Members:
        model: An ImplicitronModelBase object.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with
            initial weights unless `force_resume` is True.
        resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        force_resume: If True, throw a FileNotFoundError if `resume` is True but
            a model checkpoint cannot be found.

    """

    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = True
    resume_epoch: int = -1
    force_resume: bool = False

    def __post_init__(self):
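        # Let the config system instantiate pluggable members, in particular
        # `self.model` of the type selected by `model_class_type`.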
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Returns an instance of `ImplicitronModelBase`, possibly loaded from a
        checkpoint (if self.resume, self.resume_epoch specify so).

        Args:
            exp_dir: Root experiment directory.
            accelerator: An Accelerator object.

        Returns:
            model: The model, with weights optionally loaded from a checkpoint.

        Raises:
            FileNotFoundError: If `force_resume` is True but a checkpoint
                cannot be found.
        """
        # Determine the network outputs that should be logged
        if hasattr(self.model, "log_vars"):
            log_vars = list(self.model.log_vars)
        else:
            log_vars = ["objective"]

        if self.resume_epoch > 0:
            # Resume from a certain epoch
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
            if not os.path.isfile(model_path):
                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        else:
            # Retrieve the last checkpoint
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            logger.info(f"Found previous model {model_path}")
            if self.force_resume or self.resume:
                logger.info("Resuming.")

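                # Under multi-GPU `accelerate` launches, non-main processes remap
                # tensors saved on cuda:0 onto their own local device, so that
                # every rank does not load the checkpoint onto GPU 0.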
                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                model_state_dict = torch.load(
                    model_io.get_model_path(model_path), map_location=map_location
                )

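                # Prefer strict loading; if the checkpoint keys do not exactly
                # match the current model (e.g. after a config change), fall back
                # to non-strict loading and keep whatever weights do match.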
                try:
                    self.model.load_state_dict(model_state_dict, strict=True)
                except RuntimeError as e:
                    logger.error(e)
                    logger.info(
                        "Cannot load state dict in strict mode! -> trying non-strict"
                    )
                    self.model.load_state_dict(model_state_dict, strict=False)
                self.model.log_vars = log_vars
            else:
                logger.info("Not resuming -> starting from scratch.")
        elif self.force_resume:
            raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

        return self.model
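

# Usage note (a minimal sketch, not part of the file above): assuming the
# Implicitron config system has expanded this class via `expand_args_fields`,
# the factory can be constructed directly and pointed at an experiment
# directory; "/path/to/exp_dir" below is a placeholder.
#
#     from pytorch3d.implicitron.tools.config import expand_args_fields
#
#     expand_args_fields(ImplicitronModelFactory)
#     factory = ImplicitronModelFactory(resume=True, resume_epoch=-1)
#     model = factory(exp_dir="/path/to/exp_dir")
#
# With `resume_epoch <= 0` the most recent checkpoint in `exp_dir` is loaded;
# setting `force_resume=True` raises FileNotFoundError if none can be found.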