nerf_mlp.py 9.88 KB
Newer Older
YirongYan's avatar
YirongYan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """The MLP module used in NerfDet.

    A stack of ``net_depth`` fully connected hidden layers, each followed by
    ``hidden_activation``, with an optional output layer. After every hidden
    layer whose index ``i`` is a positive multiple of ``skip_layer``, the
    original input is concatenated to the activations (skip connection), so
    the next layer receives ``net_width + input_dim`` channels.

    Args:
        input_dim (int): The number of input tensor channels.
        output_dim (int, optional): The number of output tensor channels.
            Required when ``output_enabled`` is True. When the output layer
            is disabled it is overwritten with the channel count of the last
            hidden representation. Defaults to None.
        net_depth (int): The number of hidden layers. Defaults to 8.
        net_width (int): The width of each hidden layer. Defaults to 256.
        skip_layer (int, optional): Period of the input skip connections;
            ``None`` or 0 disables them. Defaults to 4.
        hidden_init (Callable, optional): Initializer applied to the hidden
            layer weights; ``None`` keeps PyTorch's default initialization.
            Defaults to ``nn.init.xavier_uniform_``.
        hidden_activation (Callable): The activation function applied after
            every hidden layer. Defaults to ``nn.ReLU()``.
        output_enabled (bool): If true, the output layer is used.
            Defaults to True.
        output_init (Callable, optional): Initializer applied to the output
            layer weight; ``None`` keeps PyTorch's default initialization.
            Defaults to ``nn.init.xavier_uniform_``.
        output_activation (Callable, optional): The activation function of
            the output layer; ``None`` means no activation.
            Defaults to ``nn.Identity()``.
        bias_enabled (bool): If true, the linear layers use a bias.
            Defaults to True.
        bias_init (Callable, optional): Initializer applied to every bias;
            ``None`` keeps PyTorch's default initialization.
            Defaults to ``nn.init.zeros_``.

    Raises:
        ValueError: If ``output_enabled`` is True but ``output_dim`` is None.
    """

    # NOTE: the module defaults below are evaluated once and shared by every
    # instance that uses them; they are stateless (ReLU / Identity), so the
    # sharing is harmless.
    def __init__(
        self,
        input_dim: int,
        output_dim: Optional[int] = None,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: Optional[int] = 4,
        hidden_init: Optional[Callable] = nn.init.xavier_uniform_,
        hidden_activation: Callable = nn.ReLU(),
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = nn.Identity(),
        bias_enabled: bool = True,
        bias_init: Optional[Callable] = nn.init.zeros_,
    ):
        super().__init__()
        if output_enabled and output_dim is None:
            raise ValueError(
                '`output_dim` must be given when `output_enabled` is True.')
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init

        self.hidden_layers = nn.ModuleList()
        in_features = self.input_dim
        for i in range(self.net_depth):
            self.hidden_layers.append(
                nn.Linear(in_features, self.net_width, bias=bias_enabled))
            in_features = self.net_width
            if self._is_skip(i):
                # forward() appends the raw input after this layer.
                in_features += self.input_dim
        if self.output_enabled:
            self.output_layer = nn.Linear(
                in_features, self.output_dim, bias=bias_enabled)
        else:
            # Expose the width of the final hidden representation so callers
            # can size the layers that consume it.
            self.output_dim = in_features

        self.initialize()

    def _is_skip(self, i: int) -> bool:
        """Whether the raw input is concatenated after hidden layer ``i``."""
        # A falsy skip_layer (None or 0) disables skips and avoids ``i % 0``.
        return bool(self.skip_layer) and i > 0 and i % self.skip_layer == 0

    def _init_linear(self, layer: nn.Linear,
                     weight_init: Optional[Callable]) -> None:
        """Apply ``weight_init`` to the weight and ``bias_init`` to the bias."""
        if weight_init is not None:
            weight_init(layer.weight)
        if self.bias_enabled and self.bias_init is not None:
            self.bias_init(layer.bias)

    def initialize(self):
        """(Re-)initialize all linear layers with the configured functions."""
        for layer in self.hidden_layers:
            self._init_linear(layer, self.hidden_init)
        if self.output_enabled:
            self._init_linear(self.output_layer, self.output_init)

    def forward(self, x):
        """Run the MLP on ``x``, whose last dim has ``input_dim`` channels."""
        inputs = x
        for i, layer in enumerate(self.hidden_layers):
            x = self.hidden_activation(layer(x))
            if self._is_skip(i):
                x = torch.cat([x, inputs], dim=-1)
        if self.output_enabled:
            x = self.output_layer(x)
            if self.output_activation is not None:
                x = self.output_activation(x)
        return x


class DenseLayer(MLP):
    """An :class:`MLP` without hidden layers.

    With ``net_depth=0`` the module reduces to the MLP output layer (a single
    ``nn.Linear`` from ``input_dim`` to ``output_dim``) followed by the
    output activation. Any extra keyword arguments are forwarded to
    :class:`MLP` unchanged.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        # Zero hidden layers: only the output projection is built.
        super().__init__(input_dim, output_dim, net_depth=0, **kwargs)


class NerfMLP(nn.Module):
    """The Nerf-MLP Module.

    A trunk :class:`MLP` (without output layer) processes the input,
    optionally concatenated with extra features. A linear head on the trunk
    output predicts the raw density. The raw colour comes either from a
    linear head on the trunk output (``condition_dim <= 0``) or, when
    ``condition_dim > 0``, from a bottleneck layer whose output is
    concatenated with a condition tensor and fed to a second, smaller MLP.

    Args:
        input_dim (int): The number of input tensor channels.
        condition_dim (int): The number of condition tensor channels. If not
            positive, the colour head ignores any condition and ``forward``
            must be called without one.
        feature_dim (int): The number of extra feature channels concatenated
            to the input. Defaults to 0.
        net_depth (int): The depth of the trunk MLP. Defaults to 8.
        net_width (int): The width of the trunk MLP (also the bottleneck
            width). Defaults to 256.
        skip_layer (int): Period of the trunk input skip connections.
            Defaults to 4.
        net_depth_condition (int): The depth of the conditioned colour MLP.
            Defaults to 1.
        net_width_condition (int): The width of the conditioned colour MLP.
            Defaults to 128.
    """

    def __init__(
        self,
        input_dim: int,
        condition_dim: int,
        feature_dim: int = 0,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: int = 4,
        net_depth_condition: int = 1,
        net_width_condition: int = 128,
    ):
        super().__init__()
        self.condition_dim = condition_dim
        self.base = MLP(
            input_dim=input_dim + feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
        )
        hidden_features = self.base.output_dim
        self.sigma_layer = DenseLayer(hidden_features, 1)

        if condition_dim > 0:
            self.bottleneck_layer = DenseLayer(hidden_features, net_width)
            self.rgb_layer = MLP(
                input_dim=net_width + condition_dim,
                output_dim=3,
                net_depth=net_depth_condition,
                net_width=net_width_condition,
                skip_layer=None,
            )
        else:
            self.rgb_layer = DenseLayer(hidden_features, 3)

    def _trunk(self, x, features=None):
        """Run the trunk MLP, concatenating ``features`` to ``x`` if given."""
        if features is not None:
            x = torch.cat([x, features], dim=-1)
        return self.base(x)

    @staticmethod
    def _broadcast_condition(condition, target):
        """Expand ``condition`` to the leading dims of ``target``.

        The leading dims of ``condition`` are aligned with the first leading
        dims of ``target``; singleton dims are inserted for the remaining
        ones and then expanded (no copy).
        """
        n_dim = condition.shape[-1]
        shape = (list(condition.shape[:-1]) +
                 [1] * (target.dim() - condition.dim()) + [n_dim])
        return condition.view(shape).expand(
            list(target.shape[:-1]) + [n_dim])

    def query_density(self, x, features=None):
        """Calculate the raw sigma (no activation applied)."""
        return self.sigma_layer(self._trunk(x, features))

    def forward(self, x, condition=None, features=None):
        """Predict the raw colour and the raw density.

        Args:
            x (Tensor): Input with ``input_dim`` channels in the last dim.
            condition (Tensor, optional): Condition with ``condition_dim``
                channels in the last dim. Required if and only if the module
                was built with ``condition_dim > 0``. If its leading dims
                differ from those of ``x``, they are aligned with the first
                leading dims of ``x`` and broadcast over the rest, e.g. a
                ``(num_rays, C)`` condition for ``(num_rays, num_samples, D)``
                inputs.
            features (Tensor, optional): Extra features with ``feature_dim``
                channels, concatenated to ``x``.

        Returns:
            tuple[Tensor, Tensor]: ``(raw_rgb, raw_sigma)`` with 3 and 1
            channels in the last dim; no squashing activation is applied.

        Raises:
            ValueError: If ``condition`` is given although
                ``condition_dim <= 0``, or missing although
                ``condition_dim > 0``.
        """
        hidden = self._trunk(x, features)
        raw_sigma = self.sigma_layer(hidden)
        if self.condition_dim > 0:
            if condition is None:
                raise ValueError(
                    '`condition` is required when condition_dim > 0.')
            if condition.shape[:-1] != hidden.shape[:-1]:
                condition = self._broadcast_condition(condition, hidden)
            bottleneck = self.bottleneck_layer(hidden)
            rgb_input = torch.cat([bottleneck, condition], dim=-1)
        else:
            if condition is not None:
                raise ValueError(
                    '`condition` was given but this NerfMLP was built with '
                    'condition_dim <= 0.')
            rgb_input = hidden
        raw_rgb = self.rgb_layer(rgb_input)
        return raw_rgb, raw_sigma


class SinusoidalEncoder(nn.Module):
    """Sinusoidal positional encoder used in NeRF.

    Every input channel is multiplied by the frequencies ``2**d`` for ``d``
    in ``[min_deg, max_deg)``. The sine and cosine of the scaled values are
    concatenated, optionally after the raw input.

    Output layout along the last dim: ``[x (if use_identity),
    sin(f_0 x), ..., sin(f_{k-1} x), cos(f_0 x), ..., cos(f_{k-1} x)]``,
    where each block has ``x_dim`` channels.

    Args:
        x_dim (int): Number of channels in the last dim of the input.
        min_deg (int): Lowest frequency exponent (inclusive).
        max_deg (int): Highest frequency exponent (exclusive).
        use_identity (bool): If True, prepend the raw input to the encoding.
            Defaults to True.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        self.register_buffer(
            'scales', torch.tensor([2**i for i in range(min_deg, max_deg)]))

    @property
    def latent_dim(self) -> int:
        """Number of channels in the last dim of :meth:`forward`'s output."""
        return (int(self.use_identity) +
                (self.max_deg - self.min_deg) * 2) * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (..., x_dim) into (..., latent_dim)."""
        if self.max_deg == self.min_deg:
            # No frequencies: only the identity part remains, which is empty
            # when use_identity is False (consistent with latent_dim).
            return x if self.use_identity else x[..., :0]
        num_freqs = self.max_deg - self.min_deg
        # (..., 1, x_dim) * (num_freqs, 1) -> (..., num_freqs, x_dim),
        # flattened frequency-major into (..., num_freqs * x_dim).
        xb = torch.reshape(
            x[..., None, :] * self.scales[:, None],
            list(x.shape[:-1]) + [num_freqs * self.x_dim],
        )
        # sin(t + pi/2) == cos(t): a single sin call yields both halves.
        latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
        if self.use_identity:
            latent = torch.cat([x, latent], dim=-1)
        return latent


class VanillaNeRF(nn.Module):
    """A :class:`NerfMLP` preceded by sinusoidal positional encoders.

    Points are encoded with a 3-channel encoder using frequency exponents
    0..9, and the optional condition with a 3-channel encoder using exponents
    0..3; both encodings keep the raw input. The colour output is passed
    through a sigmoid and the density through a ReLU.

    Args:
        net_depth (int): The depth of the trunk MLP. Defaults to 8.
        net_width (int): The width of the trunk MLP. Defaults to 256.
        skip_layer (int): Period of the trunk input skip connections.
            Defaults to 4.
        feature_dim (int): The number of extra feature channels.
            Defaults to 0.
        net_depth_condition (int): The depth of the conditioned colour MLP.
            Defaults to 1.
        net_width_condition (int): The width of the conditioned colour MLP.
            Defaults to 128.
    """

    def __init__(self,
                 net_depth: int = 8,
                 net_width: int = 256,
                 skip_layer: int = 4,
                 feature_dim: int = 0,
                 net_depth_condition: int = 1,
                 net_width_condition: int = 128):
        super().__init__()
        # The encoders are created first: their latent_dim sizes the MLP.
        self.posi_encoder = SinusoidalEncoder(
            x_dim=3, min_deg=0, max_deg=10, use_identity=True)
        self.view_encoder = SinusoidalEncoder(
            x_dim=3, min_deg=0, max_deg=4, use_identity=True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            feature_dim=feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

    def query_density(self, x, features=None):
        """Return the density at ``x`` (ReLU of the raw sigma)."""
        raw_sigma = self.mlp.query_density(self.posi_encoder(x), features)
        return F.relu(raw_sigma)

    def forward(self, x, condition=None, features=None):
        """Return ``(rgb, sigma)``: sigmoid of raw rgb, ReLU of raw sigma."""
        encoded_x = self.posi_encoder(x)
        encoded_condition = (None if condition is None else
                             self.view_encoder(condition))
        raw_rgb, raw_sigma = self.mlp(
            encoded_x, condition=encoded_condition, features=features)
        return torch.sigmoid(raw_rgb), F.relu(raw_sigma)