# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union

import torch
from pytorch3d.implicitron.tools.config import Configurable


# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module):
    """
    Autodecoder module

    Settings:
        encoding_dim: Embedding dimension for the decoder.
        n_instances: The maximum number of instances stored by the autodecoder.
        init_scale: Scale factor for the initial autodecoder weights.
        ignore_input: If `True`, optimizes a single code for any input.
    """

    encoding_dim: int = 0
    n_instances: int = 0
    init_scale: float = 1.0
    ignore_input: bool = False

    def __post_init__(self):
        super().__init__()
        if self.n_instances <= 0:
            # Do not init the codes at all in case we have 0 instances.
            return
        self._autodecoder_codes = torch.nn.Embedding(
            self.n_instances,
            self.encoding_dim,
            scale_grad_by_freq=True,
        )
        with torch.no_grad():
            # weight has been initialised from Normal(0, 1)
            self._autodecoder_codes.weight *= self.init_scale

        self._sequence_map = self._build_sequence_map()
        # Make sure to register hooks for correct handling of saving/loading
        # the module's _sequence_map.
        self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
        self._register_state_dict_hook(_save_sequence_map_hook)

    def _build_sequence_map(
        self, sequence_map_dict: Optional[Dict[str, int]] = None
    ) -> Dict[str, int]:
        """
        Args:
            sequence_map_dict: A dictionary used to initialize the sequence_map.

        Returns:
            sequence_map: a dictionary of key: id pairs.
        """
        # increments the counter when asked for a new value
        sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
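        # For example, the first three distinct keys looked up map to ids
        # 0, 1 and 2; once n_instances ids have been handed out, the
        # underlying iterator is exhausted and any further new key raises
        # StopIteration (caught in forward() and surfaced as a ValueError).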
        if sequence_map_dict is not None:
            # Assign all keys from the loaded sequence_map_dict to self._sequence_map.
            # Since this is done in the original order, it should generate
            # the same set of key:id pairs. We check this with an assert to be sure.
            for x, x_id in sequence_map_dict.items():
                x_id_ = sequence_map[x]
                assert x_id == x_id_
        return sequence_map

    def calc_squared_encoding_norm(self):
        if self.n_instances <= 0:
            return None
        return (self._autodecoder_codes.weight**2).mean()

    def get_encoding_dim(self) -> int:
        if self.n_instances <= 0:
            return 0
        return self.encoding_dim

    def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
        """
        Args:
            x: A batch of `N` sequence identifiers. Either a long tensor of
                shape `(N,)` with values in `[0, n_instances)`, or a list of
                `N` string keys, each of which is mapped to a unique id
                (without collisions).

        Returns:
            codes: A tensor of shape `(N, self.encoding_dim)` containing the
                sequence-specific autodecoder codes.
        """
        if self.n_instances <= 0:
            return None

        if self.ignore_input:
            x = ["singleton"]

        if isinstance(x[0], str):
            try:
                x = torch.tensor(
                    # pyre-ignore[29]
                    [self._sequence_map[elem] for elem in x],
                    dtype=torch.long,
                    device=next(self.parameters()).device,
                )
            except StopIteration:
                raise ValueError("Not enough n_instances in the autodecoder")

        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        return self._autodecoder_codes(x)

    def _load_sequence_map_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`

        Returns:
            Nothing. Reconstructs `self._sequence_map` from the `state_dict`
            if it is present there; otherwise only emits a warning.
        """
        sequence_map_key = prefix + "_sequence_map"
        if sequence_map_key in state_dict:
            sequence_map_dict = state_dict.pop(sequence_map_key)
            self._sequence_map = self._build_sequence_map(
                sequence_map_dict=sequence_map_dict
            )
        else:
            warnings.warn("No sequence map in Autodecoder state dict!")


def _save_sequence_map_hook(
    self,
    state_dict,
    prefix,
    local_metadata,
) -> None:
    """
    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
    """
    sequence_map_key = prefix + "_sequence_map"
    sequence_map_dict = dict(self._sequence_map.items())
    state_dict[sequence_map_key] = sequence_map_dict
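

if __name__ == "__main__":
    # A rough usage sketch (not part of the original module). It assumes the
    # usual implicitron pattern of expanding a Configurable's fields with
    # `expand_args_fields` before instantiating it directly; in a real
    # pipeline the Autodecoder is typically built through the surrounding
    # config system instead.
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(Autodecoder)
    autodecoder = Autodecoder(encoding_dim=8, n_instances=4)

    # Integer keys index the embedding table directly.
    codes = autodecoder(torch.arange(4, dtype=torch.long))
    print(codes.shape)  # torch.Size([4, 8])

    # String keys are assigned consecutive ids on first use via _sequence_map.
    codes = autodecoder(["seq_a", "seq_b"])
    print(codes.shape)  # torch.Size([2, 8])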