ngp.py 9.21 KB
Newer Older
1
2
3
4
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

Ruilong Li's avatar
Ruilong Li committed
5
6
from typing import Callable, List, Union

7
import numpy as np
Ruilong Li's avatar
Ruilong Li committed
8
9
10
11
12
13
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import tinycudann as tcnn
Ruilong Li's avatar
Ruilong Li committed
14
except ImportError as e:
Ruilong Li's avatar
Ruilong Li committed
15
    print(
Ruilong Li's avatar
Ruilong Li committed
16
        f"Error: {e}! "
Ruilong Li's avatar
Ruilong Li committed
17
18
19
20
21
22
        "Please install tinycudann by: "
        "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
    )
    exit()


Ruilong Li's avatar
Ruilong Li committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential activation with a clamped gradient.

    The forward pass is a plain ``exp``; the backward pass clamps the saved
    input at 15 before exponentiating so that very large activations cannot
    produce an exploding gradient.

    Ported from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Keep the raw input around for the clamped backward pass.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (saved_input,) = ctx.saved_tensors
        # Clamp *only* in the gradient path; the forward value is untouched.
        return g * torch.exp(saved_input.clamp(max=15))


trunc_exp = _TruncExp.apply
Ruilong Li's avatar
Ruilong Li committed
40

Ruilong Li's avatar
Ruilong Li committed
41

Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
42
43
44
def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    ord: Union[str, int] = 2,  # alternative: float("inf")
    eps: float = 1e-6,
    derivative: bool = False,
):
    """Contract points into the unit sphere (MipNeRF-360 style).

    Points are first normalized by ``aabb`` into [-1, 1]^3. Points whose norm
    exceeds 1 are squashed so their radius lands in (1, 2); finally everything
    is mapped affinely into [0, 1]^3.

    Args:
        x: Positions, shape (..., 3).
        aabb: Bounding box as 6 values (min_xyz, max_xyz).
        ord: Norm order used to measure the radius.
        eps: Lower clamp applied to the returned derivative.
        derivative: If True, return the elementwise derivative of the
            contraction instead of the contracted points.
    """
    lo, hi = torch.split(aabb, 3, dim=-1)
    # Normalize into [-1, 1]^3 relative to the box.
    x = (x - lo) / (hi - lo)
    x = 2 * x - 1
    mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
    outside = mag.squeeze(-1) > 1

    if derivative:
        # Elementwise derivative of the contraction; identity (1) inside the
        # unit ball (the div-by-zero rows there are overwritten below).
        dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
            1 / mag**3 - (2 * mag - 1) / mag**4
        )
        dev[~outside] = 1.0
        return torch.clamp(dev, min=eps)

    # Squash far points to radius (1, 2), then map [-2, 2] -> [0, 1].
    x[outside] = (2 - 1 / mag[outside]) * (x[outside] / mag[outside])
    return x / 4 + 0.5


69
70
class NGPRadianceField(torch.nn.Module):
    """Instance-NGP Radiance Field.

    A hash-grid-encoded position feeds a fused MLP (``mlp_base``) that outputs
    one raw density value plus a ``geo_feat_dim``-dimensional geometry feature.
    A second MLP (``mlp_head``) maps that feature — optionally concatenated
    with a spherical-harmonics encoding of the view direction — to RGB.

    Args:
        aabb: Scene bounding box as 2 * num_dim values (min_xyz, max_xyz).
        num_dim: Spatial dimensionality of positions (default 3).
        use_viewdirs: If True, color is conditioned on the view direction.
        density_activation: Maps raw MLP output to non-negative density.
        unbounded: If True, contract positions via ``contract_to_unisphere``
            instead of plain aabb normalization.
        base_resolution: Coarsest hash-grid resolution.
        max_resolution: Finest hash-grid resolution.
        geo_feat_dim: Size of the geometry feature passed to the color MLP.
            If 0, no ``mlp_head`` is created.
        n_levels: Number of hash-grid levels.
        log2_hashmap_size: Log2 of the hash-table size per level.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        use_viewdirs: bool = True,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 4096,
        geo_feat_dim: int = 15,
        n_levels: int = 16,
        log2_hashmap_size: int = 19,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.use_viewdirs = use_viewdirs
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.geo_feat_dim = geo_feat_dim
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor between consecutive hash-grid levels.
        per_level_scale = np.exp(
            (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
        ).tolist()

        if self.use_viewdirs:
            self.direction_encoding = tcnn.Encoding(
                n_input_dims=num_dim,
                encoding_config={
                    "otype": "Composite",
                    "nested": [
                        {
                            "n_dims_to_encode": 3,
                            "otype": "SphericalHarmonics",
                            "degree": 4,
                        },
                        # {"otype": "Identity", "n_bins": 4, "degree": 4},
                    ],
                },
            )

        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1 + self.geo_feat_dim,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )
        if self.geo_feat_dim > 0:
            self.mlp_head = tcnn.Network(
                n_input_dims=(
                    (
                        self.direction_encoding.n_output_dims
                        if self.use_viewdirs
                        else 0
                    )
                    + self.geo_feat_dim
                ),
                n_output_dims=3,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                },
            )

    def query_density(self, x, return_feat: bool = False):
        """Query density (and optionally the geometry feature) at positions.

        Positions outside the (normalized) unit cube get zero density via
        ``selector``. Returns ``density`` of shape (..., 1), plus the
        geometry feature of shape (..., geo_feat_dim) if ``return_feat``.
        """
        if self.unbounded:
            x = contract_to_unisphere(x, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            x = (x - aabb_min) / (aabb_max - aabb_min)
        # True only for points strictly inside the unit cube after mapping.
        selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
        x = (
            self.mlp_base(x.view(-1, self.num_dim))
            .view(list(x.shape[:-1]) + [1 + self.geo_feat_dim])
            .to(x)
        )
        density_before_activation, base_mlp_out = torch.split(
            x, [1, self.geo_feat_dim], dim=-1
        )
        # Zero out density for out-of-bounds samples.
        density = (
            self.density_activation(density_before_activation)
            * selector[..., None]
        )
        if return_feat:
            return density, base_mlp_out
        else:
            return density

    def _query_rgb(self, dir, embedding, apply_act: bool = True):
        """Decode RGB from the geometry feature (and view direction, if used)."""
        # tcnn requires directions in the range [0, 1]
        if self.use_viewdirs:
            dir = (dir + 1.0) / 2.0
            d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
            h = torch.cat([d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1)
        else:
            h = embedding.reshape(-1, self.geo_feat_dim)
        rgb = (
            self.mlp_head(h)
            .reshape(list(embedding.shape[:-1]) + [3])
            .to(embedding)
        )
        if apply_act:
            rgb = torch.sigmoid(rgb)
        return rgb

    def forward(
        self,
        positions: torch.Tensor,
        directions: torch.Tensor = None,
    ):
        """Return (rgb, density) for the given positions (and directions).

        ``directions`` is required in practice when ``use_viewdirs`` is True;
        it is ignored by ``_query_rgb`` otherwise.
        """
        if self.use_viewdirs and (directions is not None):
            assert (
                positions.shape == directions.shape
            ), f"{positions.shape} v.s. {directions.shape}"
        # BUGFIX: previously density/rgb were only computed inside the
        # viewdirs branch, so use_viewdirs=False (or directions=None) raised
        # UnboundLocalError at the return statement.
        density, embedding = self.query_density(positions, return_feat=True)
        rgb = self._query_rgb(directions, embedding=embedding)
        return rgb, density


class NGPDensityField(torch.nn.Module):
    """Instance-NGP Density Field used for resampling.

    A small hash-grid + fused-MLP network that predicts density only; used as
    a cheap proposal network for resampling along rays.

    Args:
        aabb: Scene bounding box as 2 * num_dim values (min_xyz, max_xyz).
        num_dim: Spatial dimensionality of positions (default 3).
        density_activation: Maps raw MLP output to non-negative density.
        unbounded: If True, contract positions via ``contract_to_unisphere``.
        base_resolution: Coarsest hash-grid resolution.
        max_resolution: Finest hash-grid resolution.
        n_levels: Number of hash-grid levels.
        log2_hashmap_size: Log2 of the hash-table size per level.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 128,
        n_levels: int = 5,
        log2_hashmap_size: int = 17,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor between consecutive hash-grid levels.
        growth = (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
        per_level_scale = np.exp(growth).tolist()

        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )

    def forward(self, positions: torch.Tensor):
        """Return density of shape (..., 1); zero outside the unit cube."""
        if self.unbounded:
            positions = contract_to_unisphere(positions, self.aabb)
        else:
            lo, hi = torch.split(self.aabb, self.num_dim, dim=-1)
            positions = (positions - lo) / (hi - lo)
        # True only for points strictly inside the normalized unit cube.
        inside = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        raw = (
            self.mlp_base(positions.view(-1, self.num_dim))
            .view(list(positions.shape[:-1]) + [1])
            .to(positions)
        )
        return self.density_activation(raw) * inside[..., None]