"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

from typing import Callable, List, Union

import numpy as np
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import tinycudann as tcnn
except ImportError as e:
    print(
        f"Error: {e}! "
        "Please install tinycudann by: "
        "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
    )
    exit()


class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential activation whose backward pass clamps the exponent.

    Borrowed from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Keep the raw input around; backward needs it for the clamped grad.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (x,) = ctx.saved_tensors
        # Clamp the exponent at 15 so huge activations cannot produce
        # overflowing (inf/nan) gradients.
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply


def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    ord: Union[str, int] = 2,
    eps: float = 1e-6,
    derivative: bool = False,
):
    """Contract unbounded points into the unit cube (mip-NeRF-360 style).

    Points inside ``aabb`` are mapped linearly into [0.25, 0.75]; points
    outside are squeezed by ``2 - 1 / |x|`` so that all of space lands in
    [0, 1].

    Args:
        x: Positions with shape (..., d), where ``aabb`` has 2 * d entries.
        aabb: Axis-aligned box ``[min..., max...]`` flattened to (2 * d,).
        ord: Norm order used to measure the contraction magnitude.
        eps: Lower clamp applied to the returned derivative.
        derivative: If True, return the per-coordinate derivative of the
            contraction instead of the contracted points.

    Returns:
        Contracted positions in [0, 1] (or their derivative, same shape).
    """
    # Generalized: split the box at its midpoint instead of hard-coding
    # 3 dimensions, so non-3D inputs work too.
    num_dim = aabb.shape[-1] // 2
    aabb_min, aabb_max = torch.split(aabb, num_dim, dim=-1)
    x = (x - aabb_min) / (aabb_max - aabb_min)
    x = x * 2 - 1  # aabb is at [-1, 1]
    mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
    mask = mag.squeeze(-1) > 1  # True where the point lies outside the box

    if derivative:
        dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
            1 / mag**3 - (2 * mag - 1) / mag**4
        )
        # Inside the box the mapping is linear, so its derivative is 1.
        dev[~mask] = 1.0
        dev = torch.clamp(dev, min=eps)
        return dev
    else:
        # Squeeze outside points into the [-2, 2] shell.
        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
        return x


class NGPRadianceField(torch.nn.Module):
    """Instant-NGP radiance field.

    A hash-grid encoding fused with a small MLP produces density plus a
    geometry feature; an optional second MLP conditioned on the (SH-encoded)
    view direction decodes the feature into RGB.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        use_viewdirs: bool = True,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 4096,
        geo_feat_dim: int = 15,
        n_levels: int = 16,
        log2_hashmap_size: int = 19,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)

        # A rectangular aabb leads to uneven hash collisions (and thus bad
        # performance), so snap the box to a cube around its center.
        box_center = (aabb[..., :num_dim] + aabb[..., num_dim:]) / 2.0
        box_size = (aabb[..., num_dim:] - aabb[..., :num_dim]).max()
        aabb = torch.cat(
            [box_center - box_size / 2.0, box_center + box_size / 2.0], dim=-1
        )

        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.use_viewdirs = use_viewdirs
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.geo_feat_dim = geo_feat_dim
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor between consecutive hash-grid levels.
        log_growth = (np.log(max_resolution) - np.log(base_resolution)) / (
            n_levels - 1
        )
        per_level_scale = np.exp(log_growth).tolist()

        if self.use_viewdirs:
            # Spherical-harmonics encoding of the 3D view direction.
            self.direction_encoding = tcnn.Encoding(
                n_input_dims=num_dim,
                encoding_config={
                    "otype": "Composite",
                    "nested": [
                        {
                            "n_dims_to_encode": 3,
                            "otype": "SphericalHarmonics",
                            "degree": 4,
                        },
                    ],
                },
            )

        # Hash-grid encoding + fused MLP: one raw-density channel followed by
        # ``geo_feat_dim`` geometry features.
        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1 + self.geo_feat_dim,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )
        if self.geo_feat_dim > 0:
            # Color head: geometry feature (+ encoded direction) -> raw RGB.
            self.mlp_head = tcnn.Network(
                n_input_dims=(
                    (
                        self.direction_encoding.n_output_dims
                        if self.use_viewdirs
                        else 0
                    )
                    + self.geo_feat_dim
                ),
                n_output_dims=3,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                },
            )

    def query_density(self, x, return_feat: bool = False):
        """Return per-point density (and the geometry feature if requested)."""
        # Map positions into the unit cube expected by the hash encoding.
        if self.unbounded:
            x = contract_to_unisphere(x, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            x = (x - aabb_min) / (aabb_max - aabb_min)
        # Points that land outside [0, 1]^d contribute zero density.
        inside = ((x > 0.0) & (x < 1.0)).all(dim=-1)
        out = self.mlp_base(x.view(-1, self.num_dim))
        out = out.view(list(x.shape[:-1]) + [1 + self.geo_feat_dim]).to(x)
        raw_density, geo_feat = torch.split(
            out, [1, self.geo_feat_dim], dim=-1
        )
        density = self.density_activation(raw_density) * inside[..., None]
        return (density, geo_feat) if return_feat else density

    def _query_rgb(self, dir, embedding, apply_act: bool = True):
        """Decode RGB from the geometry feature (and view direction)."""
        if self.use_viewdirs:
            # tcnn wants directions in [0, 1]; incoming ones are in [-1, 1].
            dir = (dir + 1.0) / 2.0
            enc = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
            h = torch.cat(
                [enc, embedding.reshape(-1, self.geo_feat_dim)], dim=-1
            )
        else:
            h = embedding.reshape(-1, self.geo_feat_dim)
        rgb = self.mlp_head(h)
        rgb = rgb.reshape(list(embedding.shape[:-1]) + [3]).to(embedding)
        return torch.sigmoid(rgb) if apply_act else rgb

    def forward(
        self,
        positions: torch.Tensor,
        directions: torch.Tensor = None,
    ):
        """Return (rgb, density) for the given positions/directions."""
        if self.use_viewdirs and (directions is not None):
            assert (
                positions.shape == directions.shape
            ), f"{positions.shape} v.s. {directions.shape}"
        density, embedding = self.query_density(positions, return_feat=True)
        rgb = self._query_rgb(directions, embedding=embedding)
        return rgb, density  # type: ignore


class NGPDensityField(torch.nn.Module):
    """Instant-NGP density field, used as a coarse field for resampling.

    Same hash-grid + fused-MLP recipe as the radiance field, but with a
    single density output and no color head.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 128,
        n_levels: int = 5,
        log2_hashmap_size: int = 17,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor between consecutive hash-grid levels.
        log_growth = (np.log(max_resolution) - np.log(base_resolution)) / (
            n_levels - 1
        )
        per_level_scale = np.exp(log_growth).tolist()

        # Hash-grid encoding + fused MLP emitting a single raw density.
        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )

    def forward(self, positions: torch.Tensor):
        """Return per-point density for ``positions``."""
        # Map positions into the unit cube expected by the hash encoding.
        if self.unbounded:
            positions = contract_to_unisphere(positions, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            positions = (positions - aabb_min) / (aabb_max - aabb_min)
        # Points that land outside [0, 1]^d contribute zero density.
        inside = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        raw = self.mlp_base(positions.view(-1, self.num_dim))
        raw = raw.view(list(positions.shape[:-1]) + [1]).to(positions)
        return self.density_activation(raw) * inside[..., None]