# model_factory.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7
8
# pyre-unsafe

9
10
import logging
import os
11
from typing import Optional
12
13
14
15
16

import torch.optim

from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
17
from pytorch3d.implicitron.tools import model_io
18
19
20
21
22
23
24
25
26
27
28
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats

logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):
    """
    Base interface for factories that produce an Implicitron model.

    Both methods defined here only raise NotImplementedError; concrete
    factories are expected to override them.

    Members:
        resume: If True, resume from the last checkpoint.
    """

    resume: bool = True  # resume from the last checkpoint

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Initialize the model (possibly from a previously saved state).

        Returns: An instance of ImplicitronModelBase.
        """
        raise NotImplementedError()

    def load_stats(self, **kwargs) -> Stats:
        """
        Initialize or load a Stats object.
        """
        raise NotImplementedError()


@registry.register
class ImplicitronModelFactory(ModelFactoryBase):
    """
    A factory class that initializes an implicit rendering model.

    Members:
        model: An ImplicitronModelBase object.
        model_class_type: Defaults to "GenericModel". NOTE(review): presumably
            the registered name of the ImplicitronModelBase implementation
            that `run_auto_creation` instantiates into `model` -- confirm
            against the Implicitron config system.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with
            initial weights unless `force_resume` is True.
        resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        force_resume: If True, throw a FileNotFoundError if `resume` is True but
            a model checkpoint cannot be found.
    """

    # pyre-fixme[13]: Attribute `model` is never initialized.
    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = True
    resume_epoch: int = -1
    force_resume: bool = False

    def __post_init__(self) -> None:
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Returns an instance of `ImplicitronModelBase`, possibly loaded from a
        checkpoint (if self.resume, self.resume_epoch specify so).

        Args:
            exp_dir: Root experiment directory.
            accelerator: An Accelerator object.

        Returns:
            model: The model with optionally loaded weights from checkpoint

        Raises:
            FileNotFoundError: if `force_resume` is True but no checkpoint is
                found in `exp_dir`.
            ValueError: if resuming was requested (`resume` or `force_resume`)
                with `resume_epoch` > 0 and the checkpoint for that epoch is
                not an existing file.
        """
        # Determine the network outputs that should be logged. This is read
        # before any weights are loaded and written back after a resume.
        log_vars = list(getattr(self.model, "log_vars", ["objective"]))

        model_path = self._find_checkpoint(exp_dir)
        if model_path is None:
            if self.force_resume:
                raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
            return self.model

        logger.info(f"Found previous model {model_path}")
        if not (self.force_resume or self.resume):
            logger.info("Not resuming -> starting from scratch.")
            return self.model

        logger.info("Resuming.")
        self._load_weights(model_path, accelerator)
        self.model.log_vars = log_vars
        return self.model

    def _find_checkpoint(self, exp_dir: str) -> Optional[str]:
        """Return the checkpoint path selected by `resume_epoch`, or None."""
        if self.resume_epoch <= 0:
            # Retrieve the last checkpoint (None when there is none to use).
            return model_io.find_last_checkpoint(exp_dir)

        # Resume from a certain epoch
        model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        if os.path.isfile(model_path):
            return model_path
        # `resume_epoch` only applies when resuming, so a missing epoch
        # checkpoint is an error only if a resume was actually requested.
        if self.resume or self.force_resume:
            raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        return None

    def _load_weights(
        self, model_path: str, accelerator: Optional[Accelerator]
    ) -> None:
        """Load the state dict stored for `model_path` into `self.model`."""
        map_location = None
        if accelerator is not None and not accelerator.is_local_main_process:
            # Remap tensors saved on cuda:0 onto this process's local device.
            map_location = {"cuda:0": f"cuda:{accelerator.local_process_index}"}

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted `exp_dir`.
        model_state_dict = torch.load(
            model_io.get_model_path(model_path), map_location=map_location
        )

        try:
            self.model.load_state_dict(model_state_dict, strict=True)
        except RuntimeError as e:
            logger.error(e)
            logger.info("Cannot load state dict in strict mode! -> trying non-strict")
            self.model.load_state_dict(model_state_dict, strict=False)