"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import math
from typing import Callable, List, Union

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import tinycudann as tcnn
except ImportError as e:
    print(
        f"Error: {e}! "
        "Please install tinycudann by: "
        "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
    )
    exit()


class _TruncExp(Function):  # pylint: disable=abstract-method
    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
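        # clamp the saved input before exponentiating so the gradient stays finite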
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply
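# A minimal sketch of the behavior: the forward pass is a plain exp, while the
# backward clamps its input at 15 to keep gradients bounded, e.g.:
#   x = torch.tensor([20.0], requires_grad=True)
#   trunc_exp(x).backward()   # x.grad == exp(15) rather than exp(20)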


def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    eps: float = 1e-6,
    derivative: bool = False,
):
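    """Contract unbounded points into the unit cube (the 2 - 1/||x|| scene
    contraction used for unbounded scenes, as in mip-NeRF 360).

    Positions are first normalized by the aabb; points whose normalized norm
    exceeds 1 are squashed into a radius-2 ball, and everything is then
    rescaled into [0, 1]^d. With ``derivative=True`` a per-axis derivative of
    the contraction (clamped to at least ``eps``) is returned instead.
    """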
    aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
    x = (x - aabb_min) / (aabb_max - aabb_min)
    x = x * 2 - 1  # aabb is at [-1, 1]
    mag = x.norm(dim=-1, keepdim=True)
    mask = mag.squeeze(-1) > 1

    if derivative:
        dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
            1 / mag**3 - (2 * mag - 1) / mag**4
        )
        dev[~mask] = 1.0
        dev = torch.clamp(dev, min=eps)
        return dev
    else:
        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
        x = x / 4 + 0.5  # [-inf, inf] is contracted into [0, 1]
        return x
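
# Quick usage sketch for contract_to_unisphere (illustrative values, not part
# of the original file):
#   aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0])
#   pts = torch.randn(8, 3) * 10.0            # samples far outside the box
#   out = contract_to_unisphere(pts, aabb)    # every row lands in [0, 1]^3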


class NGPradianceField(torch.nn.Module):
    """Instance-NGP radiance Field"""

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        use_viewdirs: bool = True,
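        # exp shifted by -1: a zero pre-activation yields a density of exp(-1) ≈ 0.37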
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        hidden_dim: int = 64,
        geo_feat_dim: int = 15,
        n_levels: int = 16,
        max_res: int = 1024,
        base_res: int = 16,
        log2_hashmap_size: int = 19,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.use_viewdirs = use_viewdirs
        self.density_activation = density_activation
        self.unbounded = unbounded

        self.geo_feat_dim = geo_feat_dim
        per_level_scale = math.exp(
            (math.log(max_res) - math.log(base_res)) / (n_levels - 1)
        )
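        # with the defaults (base_res=16, max_res=1024, n_levels=16) this growth
        # factor is 64 ** (1 / 15) ≈ 1.32 between consecutive hash-grid levels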

        if self.use_viewdirs:
            self.direction_encoding = tcnn.Encoding(
                n_input_dims=num_dim,
                encoding_config={
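                    # a degree-4 spherical-harmonics encoding maps each direction
                    # to 16 (= 4**2) features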
                    "otype": "Composite",
                    "nested": [
                        {
                            "n_dims_to_encode": 3,
                            "otype": "SphericalHarmonics",
                            "degree": 4,
                        },
                        # {"otype": "Identity", "n_bins": 4, "degree": 4},
                    ],
                },
            )

        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1 + self.geo_feat_dim,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_res,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": 1,
            },
        )
        if self.geo_feat_dim > 0:
            self.mlp_head = tcnn.Network(
                n_input_dims=(
                    (
                        self.direction_encoding.n_output_dims
                        if self.use_viewdirs
                        else 0
                    )
                    + self.geo_feat_dim
                ),
                n_output_dims=3,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "Sigmoid",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                },
            )

    def query_density(self, x, return_feat: bool = False):
        if self.unbounded:
            x = contract_to_unisphere(x, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            x = (x - aabb_min) / (aabb_max - aabb_min)
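        # samples outside the normalized aabb are masked out; their density is zeroed below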
        selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
        x = (
            self.mlp_base(x.view(-1, self.num_dim))
            .view(list(x.shape[:-1]) + [1 + self.geo_feat_dim])
            .to(x)
        )
        density_before_activation, base_mlp_out = torch.split(
            x, [1, self.geo_feat_dim], dim=-1
        )
        density = (
            self.density_activation(density_before_activation)
            * selector[..., None]
        )
        if return_feat:
            return density, base_mlp_out
        else:
            return density

    def _query_rgb(self, dir, embedding):
        # tcnn requires directions in the range [0, 1]
        if self.use_viewdirs:
            dir = (dir + 1.0) / 2.0
            d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
            h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
        else:
            h = embedding.view(-1, self.geo_feat_dim)
        rgb = (
            self.mlp_head(h)
            .view(list(embedding.shape[:-1]) + [3])
            .to(embedding)
        )
        return rgb

    def forward(
        self,
        positions: torch.Tensor,
        directions: torch.Tensor = None,
    ):
        if self.use_viewdirs and (directions is not None):
            assert (
                positions.shape == directions.shape
            ), f"{positions.shape} v.s. {directions.shape}"
        else:
            # without view directions the color head is conditioned on the
            # geometry feature alone, which requires use_viewdirs=False
            assert not self.use_viewdirs, (
                "view directions are required when use_viewdirs=True"
            )
        density, embedding = self.query_density(positions, return_feat=True)
        rgb = self._query_rgb(directions, embedding=embedding)
        return rgb, density
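

# Minimal usage sketch (assumes a CUDA device and a working tinycudann install;
# the aabb, batch size, and values below are illustrative, not part of the file):
#   radiance_field = NGPradianceField(aabb=[-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]).cuda()
#   positions = torch.rand(4096, 3, device="cuda") * 3.0 - 1.5
#   directions = torch.nn.functional.normalize(
#       torch.randn(4096, 3, device="cuda"), dim=-1
#   )
#   rgb, density = radiance_field(positions, directions)  # shapes (4096, 3), (4096, 1)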