# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import torch
import torch.nn.functional as F
from pytorch3d.structures import Meshes


def unravel_index(idx, dims) -> torch.Tensor:
    r"""
    Equivalent to np.unravel_index
    Args:
      idx: A LongTensor whose elements are indices into the
          flattened version of an array of dimensions dims.
      dims: The shape of the array to be indexed.
    Implemented only for dims=(N, H, W, D)
    Returns:
      A tensor obtained by stacking the (n, h, w, d) coordinates of every
      element of idx along dim=1.
    Raises:
      ValueError: if dims does not have exactly 4 elements.
    """
    if len(dims) != 4:
        raise ValueError("Expects a 4-element list.")
    N, H, W, D = dims

    # Integer (floor) division is required here: torch.div on integer tensors
    # performs true division in current PyTorch releases and would yield
    # fractional float "indices". Flat indices are non-negative, so floor and
    # truncation agree.
    n = idx // (H * W * D)
    # Remainder once the batch offset has been peeled off.
    rem = idx - n * H * W * D
    h = rem // (W * D)
    rem = rem - h * W * D
    w = rem // D
    d = rem - w * D
    return torch.stack((n, h, w, d), dim=1)


def ravel_index(idx, dims) -> torch.Tensor:
    """
    Flattens per-axis indices into linear indices for an array of shape dims.
    This is the inverse operation of unravel_index.
    Args:
      idx: A LongTensor of shape (N, 3). Every row holds one index triplet
          into an array of dimensions dims.
      dims: The shape of the array to be indexed.
    Implemented only for dims=(H, W, D)
    Returns:
      A tensor with one linear index per row of idx.
    Raises:
      ValueError: if dims does not have 3 elements or idx is not Nx3.
    """
    if len(dims) != 3:
        raise ValueError("Expects a 3-element list")
    if idx.shape[1] != 3:
        raise ValueError("Expects an index tensor of shape Nx3")

    # The leading extent does not contribute to the strides.
    _, W, D = dims
    # Split the columns into the index along each of the three axes.
    first, second, third = idx.unbind(dim=1)
    return first * W * D + second * D + third


@torch.no_grad()
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
    r"""
    Converts a voxel to a mesh by replacing each occupied voxel with a cube
    consisting of 12 faces and 8 vertices. Shared vertices are merged, and
    internal faces are removed.
    Args:
      voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
      thresh: A scalar threshold. If a voxel occupancy is greater than or
          equal to thresh, the voxel is considered occupied.
      device: The device of the output meshes. Defaults to the device of voxels.
      align: Defines the alignment of the mesh vertices and the grid locations.
          Has to be one of {"topleft", "corner", "center"}. See below for explanation.
          Default is "topleft".
    Returns:
      meshes: A Meshes object of the corresponding meshes.
    Raises:
      ValueError: if align is not one of the supported modes.


    The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
    is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:

                X---X----         X-------X        ---------
                |   |   |         |   |   |        | X | X |
                X---X----         ---------        ---------
                |   |   |         |   |   |        | X | X |
                ---------         X-------X        ---------

                 topleft           corner            center

    In the figure, X denote the grid locations and the squares represent the added cuboids.
    When `align="topleft"`, then the top left corner of each cuboid corresponds to the
    pixel coordinate of the input grid.
    When `align="corner"`, then the corners of the output mesh span the whole grid.
    When `align="center"`, then the grid locations form the center of the cuboids.
    """

    if device is None:
        device = voxels.device

    if align not in ["topleft", "corner", "center"]:
        raise ValueError("Align mode must be one of (topleft, corner, center).")

    if len(voxels) == 0:
        return Meshes(verts=[], faces=[])

    N, D, H, W = voxels.size()
    # vertices corresponding to a unit cube: 8x3
    cube_verts = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=torch.int64,
        device=device,
    )

    # faces corresponding to a unit cube: 12x3
    cube_faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],  # left face: 0, 1
            [2, 3, 6],
            [3, 7, 6],  # bottom face: 2, 3
            [0, 2, 6],
            [0, 6, 4],  # front face: 4, 5
            [0, 5, 1],
            [0, 4, 5],  # up face: 6, 7
            [6, 7, 5],
            [6, 5, 4],  # right face: 8, 9
            [1, 7, 3],
            [1, 5, 7],  # back face: 10, 11
        ],
        dtype=torch.int64,
        device=device,
    )

    # Two-tap averaging kernels along W, H and D respectively. The average of
    # two binary neighbours exceeds 0.5 only when both are occupied.
    wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
    wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
    wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)

    # Binary occupancy. It is moved to `device` because the kernels above and
    # every buffer below live there; without the move an explicit `device`
    # different from voxels.device fails inside conv3d.
    voxelt = voxels.ge(thresh).float().to(device)
    # N x 1 x D x H x W
    voxelt = voxelt.view(N, 1, D, H, W)

    # 1 where two adjacent voxels along the given axis are both occupied.
    # N x 1 x D x H x (W-1)
    voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
    # N x 1 x D x (H-1) x W
    voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
    # N x 1 x (D-1) x H x W
    voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()

    # 12 x N x 1 x D x H x W
    # Every face starts as "keep"; a face is dropped below when the voxel on
    # the other side of it is occupied as well (i.e. the face is internal).
    faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)

    # add left face
    faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
    faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
    # add bottom face
    faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
    faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
    # add front face
    faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
    faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
    # add up face
    faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
    faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
    # add right face
    faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
    faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
    # add back face
    faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
    faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z

    # Only occupied voxels contribute faces.
    faces_idx *= voxelt

    # N x H x W x D x 12
    faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
    # (NHWD) x 12
    faces_idx = faces_idx.contiguous()
    faces_idx = faces_idx.view(-1, cube_faces.size(0))

    # boolean to linear index
    # NF x 2: (flat voxel index, cube face index) of every kept face
    linind = torch.nonzero(faces_idx, as_tuple=False)
    # NF x 4
    nyxz = unravel_index(linind[:, 0], (N, H, W, D))

    # NF x 3: faces
    faces = torch.index_select(cube_faces, 0, linind[:, 1])

    # Reorders the (x, y, z) columns of cube_verts to (y, x, z) so that they
    # line up with the (h, w, d) columns of nyxz. Loop invariant.
    permute_idx = torch.tensor([1, 0, 2], device=device)

    grid_faces = []
    for corner in range(cube_faces.size(1)):
        # NF x 3
        xyz = torch.index_select(cube_verts, 0, faces[:, corner])
        yxz = torch.index_select(xyz, 1, permute_idx)
        # Offset the unit-cube corner by the voxel location.
        yxz += nyxz[:, 1:]
        # NF: linear index into the (H+1) x (W+1) x (D+1) vertex grid
        temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
        grid_faces.append(temp)
    # NF x 3
    grid_faces = torch.stack(grid_faces, dim=1)

    y, x, z = torch.meshgrid(
        torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
    )
    y = y.to(device=device, dtype=torch.float32)
    x = x.to(device=device, dtype=torch.float32)
    z = z.to(device=device, dtype=torch.float32)

    if align == "center":
        x = x - 0.5
        y = y - 0.5
        z = z - 0.5

    # Rescale the integer grid coordinates.
    # NOTE: for "topleft" and "center" the denominator is (size - 1), which is
    # zero when a spatial dimension has size 1.
    margin = 0.0 if align == "corner" else 1.0
    y = y * 2.0 / (H - margin) - 1.0
    x = x * 2.0 / (W - margin) - 1.0
    z = z * 2.0 / (D - margin) - 1.0

    # ((H+1)(W+1)(D+1)) x 3
    grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)

    if len(nyxz) == 0:
        # No voxel passed the threshold: return N empty meshes.
        verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
        faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
        return Meshes(verts=verts_list, faces=faces_list)

    num_verts = grid_verts.size(0)
    # Temporarily shift the face indices by the batch offset so that a single
    # flat mask can record which grid vertices are referenced in each mesh.
    grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
    idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)

    # 0 marks a vertex used by at least one face, 1 marks an idle vertex.
    idleverts.scatter_(0, grid_faces.flatten(), 0)
    grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
    # Number of faces belonging to each batch element.
    split_size = torch.bincount(nyxz[:, 0], minlength=N)
    faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))

    idleverts = idleverts.view(N, num_verts)
    # Running count of idle vertices; for a used vertex this is the number of
    # idle vertices preceding it, i.e. how far its index shifts once the idle
    # vertices are dropped.
    idlenum = idleverts.cumsum(1)

    # Keep only the vertices that are referenced by some face.
    verts_list = [
        grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
        for n in range(N)
    ]
    # Re-index the faces against the compacted vertex lists.
    faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

    return Meshes(verts=verts_list, faces=faces_list)