"test/vscode:/vscode.git/clone" did not exist on "83dbf4075d8e26193111708793390b93a123f82e"
occupancy_field.py 8.76 KB
Newer Older
Ruilong Li's avatar
Ruilong Li committed
1
2
3
4
from typing import Callable, List, Tuple, Union

import torch
from torch import nn
Ruilong Li's avatar
Ruilong Li committed
5
6

# from torch_scatter import scatter_max
Ruilong Li's avatar
Ruilong Li committed
7

8

Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
9
10
11
def meshgrid3d(
    res: List[int], device: Union[torch.device, str] = "cpu"
) -> torch.Tensor:
    """Create dense integer grid coordinates.

    Despite the name, both 3D and 2D resolutions are accepted; the number of
    coordinate channels in the output equals ``len(res)``.

    Args:
        res: resolutions for {x, y, z} (or {x, y}) dimensions.
        device: device on which the coordinates are created.

    Returns:
        torch.long with shape (*res, len(res)): dense grid coordinates using
        ``"ij"`` indexing, i.e. ``out[i, j, k] == [i, j, k]``.
    """
    assert len(res) in (2, 3), f"res must have 2 or 3 entries, got {len(res)}."
    # Allocate directly on the target device rather than building on the CPU
    # and copying the whole grid afterwards.
    axes = [torch.arange(r, dtype=torch.long, device=device) for r in res]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
Ruilong Li's avatar
Ruilong Li committed
36
37
38


class OccupancyField(nn.Module):
    """Occupancy Field that supports EMA updates. Both 2D and 3D are supported.

    Note:
        Make sure the arguments match with the ``num_dim`` -- Either 2D or 3D.

    Args:
        occ_eval_fn: A Callable function that takes in the un-normalized points x,
            with shape of (N, 2) or (N, 3) (depends on ``num_dim``),
            and outputs the occupancy of those points with shape of (N, 1).
        aabb: Scene bounding box. If ``num_dim=2`` it should be {min_x, min_y, max_x, max_y}.
            If ``num_dim=3`` it should be {min_x, min_y, min_z, max_x, max_y, max_z}.
        resolution: The field resolution. It can either be an int or a list of ints
            to specify resolution on each dimension. If ``num_dim=2`` it is for {res_x, res_y}.
            If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
        num_dim: The space dimension. Either 2 or 3. Default is 3.

    Attributes:
        aabb: Scene bounding box.
        occ_grid: The occupancy grid. It is a tensor of shape (num_cells,).
        occ_grid_binary: The binary occupancy grid. It is a tensor of shape (num_cells,).
        grid_coords: The grid coordinates. It is a tensor of shape (num_cells, num_dim).
        grid_indices: The grid indices. It is a tensor of shape (num_cells,).
    """

    aabb: torch.Tensor
    occ_grid: torch.Tensor
    occ_grid_binary: torch.Tensor
    grid_coords: torch.Tensor
    grid_indices: torch.Tensor

    def __init__(
        self,
        occ_eval_fn: Callable,
        aabb: Union[torch.Tensor, List[float]],
        resolution: Union[int, List[int]] = 128,
        num_dim: int = 3,
    ) -> None:
        super().__init__()
        self.occ_eval_fn = occ_eval_fn
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        if not isinstance(resolution, (list, tuple)):
            resolution = [resolution] * num_dim
        assert num_dim in [2, 3], "Currently only supports 2D or 3D field."
        assert aabb.shape == (
            num_dim * 2,
        ), f"shape of aabb ({aabb.shape}) should be num_dim * 2 ({num_dim * 2})."
        assert (
            len(resolution) == num_dim
        ), f"length of resolution ({len(resolution)}) should be num_dim ({num_dim})."

        self.register_buffer("aabb", aabb)
        self.resolution = resolution
        self.register_buffer("resolution_tensor", torch.tensor(resolution))
        self.num_dim = num_dim
        self.num_cells = int(torch.tensor(resolution).prod().item())

        # Stores cell occupancy values ranged in [0, 1].
        occ_grid = torch.zeros(self.num_cells)
        self.register_buffer("occ_grid", occ_grid)
        occ_grid_binary = torch.zeros(self.num_cells, dtype=torch.bool)
        self.register_buffer("occ_grid_binary", occ_grid_binary)

        # Grid coords & indices. The grid is built here for any ``num_dim`` so
        # that 2D fields work as well as 3D ones. Row ``i`` of ``grid_coords``
        # is the integer coordinate of flat cell ``i`` ("ij" / row-major order).
        axes = [torch.arange(r) for r in self.resolution]
        grid_coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        grid_coords = grid_coords.reshape(self.num_cells, self.num_dim).long()
        self.register_buffer("grid_coords", grid_coords)
        grid_indices = torch.arange(self.num_cells)
        self.register_buffer("grid_indices", grid_indices)

    @torch.no_grad()
    def _get_all_cells(self) -> torch.Tensor:
        """Returns all cells of the grid."""
        return self.grid_indices

    @torch.no_grad()
    def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor:
        """Samples ``n`` uniform cells plus up to ``n`` currently occupied cells.

        Both draws are with replacement, and the two sets may overlap, so the
        returned indices can contain duplicates. If there are at most ``n``
        occupied cells, all of them are returned.
        """
        device = self.occ_grid.device

        uniform_indices = torch.randint(self.num_cells, (n,), device=device)

        occupied_indices = torch.nonzero(self.occ_grid_binary)[:, 0]
        if n < len(occupied_indices):
            selector = torch.randint(len(occupied_indices), (n,), device=device)
            occupied_indices = occupied_indices[selector]
        indices = torch.cat([uniform_indices, occupied_indices], dim=0)
        return indices

    @torch.no_grad()
    def _update(
        self,
        step: int,
        occ_thre: float = 0.01,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
    ) -> None:
        """Update the occ field in the EMA way.

        Args:
            step: Current step. While ``step < warmup_steps`` every cell is
                re-evaluated. Afterwards, ``num_cells // 4`` uniformly sampled
                cells plus up to ``num_cells // 4`` occupied cells are.
            occ_thre: Upper bound of the binarization threshold. The threshold
                actually used is ``min(mean(occ_grid), occ_thre)``.
            ema_decay: Factor applied to the stored value of each re-evaluated
                cell before taking the max with the fresh evaluation.
            warmup_steps: Number of initial steps during which all cells are
                sampled.
        """
        # sample cells
        if step < warmup_steps:
            indices = self._get_all_cells()
        else:
            N = self.num_cells // 4
            indices = self._sample_uniform_and_occupied_cells(N)

        # Infer occupancy (originally noted as density * step_size) at one
        # random point inside each sampled cell, mapped into the aabb.
        grid_coords = self.grid_coords[indices]
        x = (
            grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
        ) / self.resolution_tensor
        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        x = x * (bb_max - bb_min) + bb_min
        # Cast to the grid dtype: the index assignment below needs matching
        # dtypes. occ_eval_fn may return another precision, e.g. float64 when
        # the aabb is float64.
        occ = self.occ_eval_fn(x).squeeze(-1).to(self.occ_grid.dtype)

        # ema update
        self.occ_grid[indices] = torch.maximum(self.occ_grid[indices] * ema_decay, occ)
        # NOTE: after warmup ``indices`` may contain duplicates. For a
        # duplicated cell, which of the writes above survives is unspecified.
        # A scatter-max would be exact, but empirically it is almost the same.
        # self.occ_grid, _ = scatter_max(
        #     occ, indices, dim=0, out=self.occ_grid * ema_decay
        # )
        self.occ_grid_binary = self.occ_grid > torch.clamp(
            self.occ_grid.mean(), max=occ_thre
        )

    @torch.no_grad()
    def query_occ(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the occupancy, given samples.

        Samples that are not strictly inside the aabb get occupancy 0 / False.

        Args:
            x: Samples with shape (..., 2) or (..., 3).

        Returns:
            float and binary occupancy values with shape (...) respectively.
        """
        assert (
            x.shape[-1] == self.num_dim
        ), "The samples are not drawn from a proper space!"

        # Normalize samples into the unit box.
        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        x = (x - bb_min) / (bb_max - bb_min)
        selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)

        # Reuse the registered resolution buffer instead of building and
        # transferring a fresh tensor on every query.
        grid_coords = torch.floor(x * self.resolution_tensor).long()
        # Row-major flattening matching the "ij" layout of ``grid_coords`` in
        # __init__: idx = ((c0 * r1) + c1) * r2 + c2 in 3D, c0 * r1 + c1 in 2D.
        grid_indices = grid_coords[..., 0]
        for dim in range(1, self.num_dim):
            grid_indices = grid_indices * self.resolution[dim] + grid_coords[..., dim]

        occs = torch.zeros(x.shape[:-1], device=x.device)
        occs[selector] = self.occ_grid[grid_indices[selector]]
        occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.bool)
        occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
        return occs, occs_binary

    @torch.no_grad()
    def every_n_step(
        self,
        step: int,
        occ_thre: float = 1e-2,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
        n: int = 16,
    ):
        """Update the field every n steps during training.

        This function is designed for training only. If for some reason you want to
        manually update the field, please use the ``_update()`` function instead.

        Args:
            step: Current training step.
            occ_thre: Threshold to binarize the occupancy field.
            ema_decay: The decay rate for EMA updates.
            warmup_steps: Sample all cells during the warmup stage. After the warmup
                stage we change the sampling strategy to 1/4 uniformly sampled cells
                together with 1/4 occupied cells.
            n: Update the field every n steps.

        Returns:
            None

        Raises:
            RuntimeError: If the module is not in training mode.
        """
        if not self.training:
            raise RuntimeError(
                "You should only call this function only during training. "
                "Please call _update() directly if you want to update the "
                "field during inference."
            )
        # self.training is guaranteed True past the guard above.
        if step % n == 0:
            self._update(
                step=step,
                occ_thre=occ_thre,
                ema_decay=ema_decay,
                warmup_steps=warmup_steps,
            )