occupancy_field.py 8.38 KB
Newer Older
Ruilong Li's avatar
Ruilong Li committed
1
2
3
4
from typing import Callable, List, Tuple, Union

import torch
from torch import nn
Ruilong Li's avatar
Ruilong Li committed
5
6

# from torch_scatter import scatter_max
Ruilong Li's avatar
Ruilong Li committed
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

def meshgrid3d(
    res: Union[Tuple[int, ...], List[int]],
    device: Union[torch.device, str] = "cpu",
):
    """Create dense grid coordinates.

    Despite the name, any number of dimensions is accepted: a length-3 ``res``
    gives a 3D grid and a length-2 ``res`` gives a 2D grid. This matters because
    2D occupancy fields also need grid coordinates.

    Args:
        res: resolution of each dimension, e.g. {x, y, z} for 3D.
        device: device on which the coordinates are created.

    Returns:
        torch.long tensor with shape (*res, len(res)). Because the grid uses
        "ij" indexing, entry ``[i, j, k]`` holds the coordinate ``[i, j, k]``.
    """
    # Allocate directly on the target device instead of building on CPU and
    # copying over afterwards.
    axes = [torch.arange(r, device=device) for r in res]
    return torch.stack(torch.meshgrid(axes, indexing="ij"), dim=-1).long()
Ruilong Li's avatar
Ruilong Li committed
33
34
35


class OccupancyField(nn.Module):
    """Occupancy Field that supports EMA updates. Both 2D and 3D are supported.

    Note:
        Make sure the arguments match with the ``num_dim`` -- Either 2D or 3D.

    Args:
        occ_eval_fn: A Callable function that takes in the un-normalized points x,
            with shape of (N, 2) or (N, 3) (depends on ``num_dim``),
            and outputs the occupancy of those points with shape of (N, 1).
        aabb: Scene bounding box. If ``num_dim=2`` it should be {min_x, min_y, max_x, max_y}.
            If ``num_dim=3`` it should be {min_x, min_y, min_z, max_x, max_y, max_z}.
        resolution: The field resolution. It can either be an int or a list of ints
            to specify resolution on each dimension. If ``num_dim=2`` it is for {res_x, res_y}.
            If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
        num_dim: The space dimension. Either 2 or 3. Default is 3.
    """

    def __init__(
        self,
        occ_eval_fn: Callable,
        aabb: Union[torch.Tensor, List[float]],
        resolution: Union[int, List[int]] = 128,
        num_dim: int = 3,
    ) -> None:
        super().__init__()
        self.occ_eval_fn = occ_eval_fn
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        if not isinstance(resolution, (list, tuple)):
            resolution = [resolution] * num_dim
        assert num_dim in [2, 3], "Currently only supports 2D or 3D field."
        assert aabb.shape == (
            num_dim * 2,
        ), f"shape of aabb ({aabb.shape}) should be num_dim * 2 ({num_dim * 2})."
        assert (
            len(resolution) == num_dim
        ), f"length of resolution ({len(resolution)}) should be num_dim ({num_dim})."

        self.register_buffer("aabb", aabb)
        self.resolution = resolution
        self.register_buffer("resolution_tensor", torch.tensor(resolution))
        self.num_dim = num_dim
        self.num_cells = torch.tensor(resolution).prod().item()

        # Stores cell occupancy values ranged in [0, 1].
        occ_grid = torch.zeros(self.num_cells)
        self.register_buffer("occ_grid", occ_grid)
        occ_grid_binary = torch.zeros(self.num_cells, dtype=torch.bool)
        self.register_buffer("occ_grid_binary", occ_grid_binary)

        # Used for thresholding occ_grid
        occ_grid_mean = occ_grid.mean()
        self.register_buffer("occ_grid_mean", occ_grid_mean)

        # Grid coords & indices. The grid is built over ``num_dim`` axes rather
        # than a hard-coded 3, so 2D fields work as well. "ij" indexing combined
        # with a row-major reshape means row i holds the coordinates of flat
        # cell index i, which is the same layout query_occ() uses.
        grid_coords = torch.stack(
            torch.meshgrid(
                [torch.arange(r) for r in self.resolution], indexing="ij"
            ),
            dim=-1,
        ).reshape(self.num_cells, self.num_dim)
        self.register_buffer("grid_coords", grid_coords.long())
        grid_indices = torch.arange(self.num_cells)
        self.register_buffer("grid_indices", grid_indices)

    @torch.no_grad()
    def _get_all_cells(self) -> torch.Tensor:
        """Returns the flat indices of all cells of the grid."""
        return self.grid_indices

    @torch.no_grad()
    def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor:
        """Sample ``n`` cells uniformly plus up to ``n`` currently-occupied cells.

        Occupied cells are those flagged in ``occ_grid_binary``. If there are more
        than ``n`` of them, ``n`` are drawn at random with replacement; otherwise
        all of them are used.

        Returns:
            1D tensor of flat cell indices. It may contain duplicates, because
            the two sample sets can overlap and both use replacement.
        """
        device = self.occ_grid.device

        uniform_indices = torch.randint(self.num_cells, (n,), device=device)

        occupied_indices = torch.nonzero(self.occ_grid_binary)[:, 0]
        if n < len(occupied_indices):
            selector = torch.randint(len(occupied_indices), (n,), device=device)
            occupied_indices = occupied_indices[selector]
        return torch.cat([uniform_indices, occupied_indices], dim=0)

    @torch.no_grad()
    def _update(
        self,
        step: int,
        occ_thre: float = 0.01,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
    ) -> None:
        """Update the occ field in the EMA way.

        Args:
            step: Current step. Before ``warmup_steps`` every cell is re-evaluated;
                afterwards a subset from ``_sample_uniform_and_occupied_cells``
                is used.
            occ_thre: Upper bound of the binarization threshold. The threshold
                actually used is ``min(mean occupancy, occ_thre)``.
            ema_decay: Multiplier applied to the stored occupancy before it is
                max-combined with the fresh evaluation.
            warmup_steps: Number of initial steps that evaluate all cells.
        """
        # sample cells
        if step < warmup_steps:
            indices = self._get_all_cells()
        else:
            N = self.num_cells // 4
            indices = self._sample_uniform_and_occupied_cells(N)

        # Infer occupancy (density * step_size) at one jittered point per
        # sampled cell. A value of -1 marks cells that were not sampled.
        tmp_occ_grid = -torch.ones_like(self.occ_grid)
        grid_coords = self.grid_coords[indices]
        x = (
            grid_coords + torch.rand_like(grid_coords.float())
        ) / self.resolution_tensor
        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        x = x * (bb_max - bb_min) + bb_min
        tmp_occ = self.occ_eval_fn(x).squeeze(-1)
        # NOTE: ``indices`` may repeat. With duplicate indices, a plain index
        # assignment keeps only one of the written values, whereas a
        # scatter-max (e.g. torch_scatter.scatter_max) would keep the largest.
        tmp_occ_grid[indices] = tmp_occ

        # ema update: only cells sampled this round (tmp >= 0) are touched.
        ema_mask = (self.occ_grid >= 0) & (tmp_occ_grid >= 0)
        self.occ_grid[ema_mask] = torch.maximum(
            self.occ_grid[ema_mask] * ema_decay, tmp_occ_grid[ema_mask]
        )
        self.occ_grid_mean = self.occ_grid.mean()
        self.occ_grid_binary = self.occ_grid > torch.clamp(
            self.occ_grid_mean, max=occ_thre
        )

    @torch.no_grad()
    def query_occ(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the occupancy, given samples.

        Samples outside the aabb get occupancy 0 / False. Samples on the lower
        face of the aabb are inside; samples on the upper face are outside.

        Args:
            x: Samples with shape (..., 2) or (..., 3).

        Returns:
            float and binary occupancy values with shape (...) respectively.
        """
        assert (
            x.shape[-1] == self.num_dim
        ), "The samples are not drawn from a proper space!"
        resolution = self.resolution_tensor

        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        x = (x - bb_min) / (bb_max - bb_min)
        # The lower bound is inclusive because x == 0 maps to the valid cell 0.
        # The upper bound is exclusive because x == 1 would floor to ``resolution``.
        selector = ((x >= 0.0) & (x < 1.0)).all(dim=-1)

        grid_coords = torch.floor(x * resolution).long()
        # Guard against float rounding that lifts x * res to res when x is
        # just below 1.
        grid_coords = torch.minimum(grid_coords, resolution - 1)

        # Row-major flat index. This matches the layout of ``grid_coords``
        # built in __init__.
        grid_indices = grid_coords[..., 0]
        for dim in range(1, self.num_dim):
            grid_indices = grid_indices * self.resolution[dim] + grid_coords[..., dim]

        occs = torch.zeros(x.shape[:-1], device=x.device)
        occs[selector] = self.occ_grid[grid_indices[selector]]
        occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.bool)
        occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
        return occs, occs_binary

    @torch.no_grad()
    def every_n_step(
        self,
        step: int,
        occ_thre: float = 1e-2,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
        n: int = 16,
    ) -> None:
        """Update the field every n steps during training.

        This function is designed for training only. If for some reason you want to
        manually update the field, please use the ``_update()`` function instead.

        Args:
            step: Current training step.
            occ_thre: Threshold to binarize the occupancy field.
            ema_decay: The decay rate for EMA updates.
            warmup_steps: Sample all cells during the warmup stage. After the warmup
                stage we change the sampling strategy to 1/4 uniformly sampled cells
                together with up to 1/4 occupied cells.
            n: Update the field every n steps.

        Returns:
            None

        Raises:
            RuntimeError: If the module is not in training mode.
        """
        if not self.training:
            raise RuntimeError(
                "You should only call this function only during training. "
                "Please call _update() directly if you want to update the "
                "field during inference."
            )
        # Training mode is guaranteed by the guard above.
        if step % n == 0:
            self._update(
                step=step,
                occ_thre=occ_thre,
                ema_decay=ema_decay,
                warmup_steps=warmup_steps,
            )