rasterize_meshes.py 19.1 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from typing import Optional

6
7
import numpy as np
import torch
facebook-github-bot's avatar
facebook-github-bot committed
8
9
from pytorch3d import _C

10

facebook-github-bot's avatar
facebook-github-bot committed
11
# TODO make the epsilon user configurable
# Small numerical tolerance used by the pure-Python reference rasterizer
# (bounding-box padding, degenerate-segment test, barycentric denominator).
kEpsilon = 1e-8

# Exclusive upper bound on the number of coarse-to-fine rasterization bins
# along each image side, i.e. on ``1 + (image_size - 1) // bin_size``.
# rasterize_meshes enforces it because of a limit in the CUDA kernel.
# NOTE: despite the historical name, this counts bins per side, not faces.
kMaxFacesPerBin = 22

facebook-github-bot's avatar
facebook-github-bot committed
18
19
20
21
22
23
24
25
26

def rasterize_meshes(
    meshes,
    image_size: int = 256,
    blur_radius: float = 0.0,
    faces_per_pixel: int = 8,
    bin_size: Optional[int] = None,
    max_faces_per_bin: Optional[int] = None,
    perspective_correct: bool = False,
    cull_backfaces: bool = False,
):
    """
    Rasterize a batch of meshes given the shape of the desired output image.
    Each mesh is rasterized onto a separate image of shape
    (image_size, image_size).

    Args:
        meshes: A Meshes object representing a batch of meshes, batch size N.
        image_size: Size in pixels of the output raster image for each mesh
            in the batch. Assumes square images.
        blur_radius: Float distance in the range [0, 2] used to expand the face
            bounding boxes for rasterization. Setting blur radius
            results in blurred edges around the shape instead of a
            hard boundary. Set to 0 for no blur. (In the pure-Python reference
            implementation this value is compared against squared distances.)
        faces_per_pixel (Optional): Number of faces to save per pixel, returning
            the nearest faces_per_pixel points along the z-axis.
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
            set it heuristically based on the shape of the input (always 0 for
            non-CUDA inputs, since binned CPU rasterization is not supported).
            This should not affect the output, but can affect the speed of the
            forward pass.
        max_faces_per_bin: Only applicable when using coarse-to-fine
            rasterization (bin_size > 0); this is the maximum number of faces
            allowed within each bin. If more than this many faces actually fall
            into a bin, an error will be raised. This should not affect the
            output values, but can affect the memory usage in the forward pass.
            Defaults to ``int(max(10000, meshes._F / 5))``.
        perspective_correct: Bool, Whether to apply perspective correction when
            computing barycentric coordinates for pixels.
        cull_backfaces: Bool, Whether to only rasterize mesh faces which are
            visible to the camera.  This assumes that vertices of
            front-facing triangles are ordered in an anti-clockwise
            fashion, and triangles that face away from the camera are
            in a clockwise order relative to the current view
            direction. NOTE: This will only work if the mesh faces are
            consistently defined with counter-clockwise ordering when
            viewed from the outside.

    Returns:
        4-element tuple containing

        - **pix_to_face**: LongTensor of shape
          (N, image_size, image_size, faces_per_pixel)
          giving the indices of the nearest faces at each pixel,
          sorted in ascending z-order.
          Concretely ``pix_to_face[n, y, x, k] = f`` means that
          ``faces_verts[f]`` is the kth closest face (in the z-direction)
          to pixel (y, x). Pixels that are hit by fewer than
          faces_per_pixel are padded with -1.
        - **zbuf**: FloatTensor of shape (N, image_size, image_size, faces_per_pixel)
          giving the NDC z-coordinates of the nearest faces at each pixel,
          sorted in ascending z-order.
          Concretely, if ``pix_to_face[n, y, x, k] = f`` then
          ``zbuf[n, y, x, k]`` is the depth of face ``f`` at pixel (y, x),
          i.e. the z-coordinates of ``face_verts[f]`` interpolated with the
          pixel's barycentric coordinates. Pixels hit by fewer than
          faces_per_pixel are padded with -1.
        - **barycentric**: FloatTensor of shape
          (N, image_size, image_size, faces_per_pixel, 3)
          giving the barycentric coordinates of the
          nearest faces at each pixel, sorted in ascending z-order.
          Concretely, if ``pix_to_face[n, y, x, k] = f`` then
          ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives
          the barycentric coords for pixel (y, x) relative to the face
          defined by ``face_verts[f]``. Pixels hit by fewer than
          faces_per_pixel are padded with -1.
        - **pix_dists**: FloatTensor of shape
          (N, image_size, image_size, faces_per_pixel)
          giving the signed squared Euclidean distance (in NDC units) in the
          x/y plane between each pixel and its nearest faces; the sign is
          negative when the pixel lies inside the face. Concretely if
          ``pix_to_face[n, y, x, k] = f`` then ``pix_dists[n, y, x, k]`` is the
          signed squared distance between the pixel (y, x) and the face given
          by vertices ``face_verts[f]``. Pixels hit with fewer than
          ``faces_per_pixel`` are padded with -1.

    Raises:
        ValueError: if bin_size is so small that the number of bins along
            each image side, ``1 + (image_size - 1) // bin_size``, is at least
            kMaxFacesPerBin.
    """
    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
    face_verts = verts_packed[faces_packed]
    mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()

    # TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
    if bin_size is None:
        if not verts_packed.is_cuda:
            # Binned CPU rasterization is not supported.
            bin_size = 0
        elif image_size <= 64:
            # TODO better heuristics for bin size.
            bin_size = 8
        else:
            # Heuristic based formula maps image_size -> bin_size as follows:
            #   65   <= image_size <= 256  -> 16
            #   257  <= image_size <= 512  -> 32
            #   513  <= image_size <= 1024 -> 64
            #   1025 <= image_size <= 2048 -> 128
            # i.e. 2 ** (ceil(log2(image_size)) - 4), never less than 16.
            bin_size = int(2 ** max(np.ceil(np.log2(image_size)) - 4, 4))

    if bin_size != 0:
        # The CUDA kernel limits how many bins can lie along each image side.
        # This is ceil(image_size / bin_size): a count of bins, not faces.
        bins_per_dim = 1 + (image_size - 1) // bin_size
        if bins_per_dim >= kMaxFacesPerBin:
            raise ValueError(
                "bin_size too small, number of faces per bin must be less than %d; got %d"
                % (kMaxFacesPerBin, bins_per_dim)
            )

    if max_faces_per_bin is None:
        max_faces_per_bin = int(max(10000, meshes._F / 5))

    # pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
    return _RasterizeFaceVerts.apply(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        bin_size,
        max_faces_per_bin,
        perspective_correct,
        cull_backfaces,
    )


class _RasterizeFaceVerts(torch.autograd.Function):
    """
    Autograd bridge for rasterize_meshes: the forward pass calls the C++/CUDA
    kernel ``_C.rasterize_meshes`` and the backward pass calls
    ``_C.rasterize_meshes_backward``.

    Args:
        face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions
            for faces in all the meshes in the batch. Concretely,
            face_verts[f, i] = [x, y, z] gives the coordinates for the
            ith vertex of the fth face. These vertices are expected to
            be in NDC coordinates in the range [-1, 1].
        mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
            faces_verts of the first face in each mesh in the batch.
        num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
            for each mesh in the batch.
        image_size, blur_radius, faces_per_pixel, bin_size, max_faces_per_bin,
        perspective_correct, cull_backfaces: same as rasterize_meshes.

    Returns:
        same as rasterize_meshes function.

    Only ``face_verts`` receives a gradient; every other input is an index
    tensor, a size or a flag.
    """

    @staticmethod
    def forward(
        ctx,
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size: int = 256,
        blur_radius: float = 0.01,
        faces_per_pixel: int = 0,
        bin_size: int = 0,
        max_faces_per_bin: int = 0,
        perspective_correct: bool = False,
        cull_backfaces: bool = False,
    ):
        # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
        kernel_outputs = _C.rasterize_meshes(
            face_verts,
            mesh_to_face_first_idx,
            num_faces_per_mesh,
            image_size,
            blur_radius,
            faces_per_pixel,
            bin_size,
            max_faces_per_bin,
            perspective_correct,
            cull_backfaces,
        )
        pix_to_face, zbuf, barycentric_coords, dists = kernel_outputs

        # The backward kernel needs the geometry and the per-pixel face ids.
        ctx.save_for_backward(face_verts, pix_to_face)
        # Face indices are integers; they have no gradient.
        ctx.mark_non_differentiable(pix_to_face)
        ctx.perspective_correct = perspective_correct
        return pix_to_face, zbuf, barycentric_coords, dists

    @staticmethod
    def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists):
        face_verts, pix_to_face = ctx.saved_tensors
        grad_face_verts = _C.rasterize_meshes_backward(
            face_verts,
            pix_to_face,
            grad_zbuf,
            grad_barycentric_coords,
            grad_dists,
            ctx.perspective_correct,
        )
        # One entry per forward input (excluding ctx): a gradient for
        # face_verts, then None for the nine non-differentiable inputs.
        return (grad_face_verts,) + (None,) * 9


242
243
244
245
246
def pix_to_ndc(i, S):
    """
    Map pixel index ``i`` on an axis of ``S`` pixels to the NDC coordinate of
    that pixel's center, so that pixel centers are spread over (-1, 1).
    """
    # Pixel width in NDC is 2 / S; the center sits half a pixel past the edge.
    center_offset = (2 * i + 1.0) / S
    return center_offset - 1


facebook-github-bot's avatar
facebook-github-bot committed
247
248
249
250
251
252
def rasterize_meshes_python(
    meshes,
    image_size: int = 256,
    blur_radius: float = 0.0,
    faces_per_pixel: int = 8,
    perspective_correct: bool = False,
    cull_backfaces: bool = False,
):
    """
    Naive PyTorch implementation of mesh rasterization with the same inputs and
    outputs as the rasterize_meshes function.

    This function is not optimized and is implemented as a comparison for the
    C++/CUDA implementations. Work that does not depend on the pixel is done
    once per face rather than once per pixel. That work is unpacking the face
    vertices, computing the signed face area, culling back faces and skipping
    zero-area faces.

    Args:
        meshes: A Meshes object representing a batch of N meshes.
        image_size: Side length in pixels of the (square) output images.
        blur_radius: Pixels outside a face still count as hits when their
            squared 2D distance to the face is below this value.
        faces_per_pixel: Number K of closest faces (smallest z) kept per pixel.
        perspective_correct: Whether to apply perspective correction to the
            barycentric coordinates.
        cull_backfaces: Whether to skip faces whose signed area is negative.

    Returns:
        Tuple (face_idxs, zbuf, bary_coords, pix_dists) of shapes
        (N, H, W, K), (N, H, W, K), (N, H, W, K, 3) and (N, H, W, K), with
        H = W = image_size. Entries not filled by any face are -1.
    """
    N = len(meshes)
    # Assume only square images.
    # TODO(T52813608) extend support for non-square images.
    H, W = image_size, image_size
    K = faces_per_pixel
    device = meshes.device

    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
    faces_verts = verts_packed[faces_packed]
    mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()

    # Initialize output tensors; -1 marks "no face at this slot".
    face_idxs = torch.full(
        (N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
    )
    zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
    bary_coords = torch.full(
        (N, H, W, K, 3), fill_value=-1, dtype=torch.float32, device=device
    )
    pix_dists = torch.full(
        (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
    )

    # Calculate all face bounding boxes.
    # pyre-fixme[16]: `Tuple` has no attribute `values`.
    x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
    x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
    y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values
    y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values

    # Expand by blur radius. Distances are squared, so the bbox is padded by
    # sqrt(blur_radius).
    blur_pad = np.sqrt(blur_radius)
    x_mins = x_mins - blur_pad - kEpsilon
    x_maxs = x_maxs + blur_pad + kEpsilon
    y_mins = y_mins - blur_pad - kEpsilon
    y_maxs = y_maxs + blur_pad + kEpsilon

    # Loop through meshes in the batch.
    for n in range(N):
        face_start_idx = mesh_to_face_first_idx[n]
        face_stop_idx = face_start_idx + num_faces_per_mesh[n]

        # Pixel-independent per-face work, done once: unpack the vertices and
        # drop faces that can never be rasterized.
        candidate_faces = []
        for f in range(face_start_idx, face_stop_idx):
            face = faces_verts[f].squeeze()
            v0, v1, v2 = face.unbind(0)
            face_area = edge_function(v0, v1, v2)

            # Ignore triangles facing away from the camera.
            if cull_backfaces and face_area < 0:
                continue

            # Ignore faces which have zero area.
            if face_area == 0.0:
                continue

            candidate_faces.append((f, v0, v1, v2))

        # Iterate through the horizontal lines of the image from top to bottom.
        for yi in range(H):
            # Y coordinate of one end of the image. Reverse the ordering
            # of yi so that +Y is pointing up in the image.
            yfix = H - 1 - yi
            yf = pix_to_ndc(yfix, H)

            # Iterate through pixels on this horizontal line, left to right.
            for xi in range(W):
                # X coordinate of one end of the image. Reverse the ordering
                # of xi so that +X is pointing to the left in the image.
                xfix = W - 1 - xi
                xf = pix_to_ndc(xfix, W)
                # The pixel center does not depend on the face.
                pxy = torch.tensor([xf, yf], dtype=torch.float32, device=device)

                top_k_points = []

                # Check whether each face in the mesh affects this pixel.
                for f, v0, v1, v2 in candidate_faces:
                    # Check if pixel is outside of face bbox.
                    outside_bbox = (
                        xf < x_mins[f]
                        or xf > x_maxs[f]
                        or yf < y_mins[f]
                        or yf > y_maxs[f]
                    )
                    if outside_bbox:
                        continue

                    # Compute barycentric coordinates and pixel z distance.
                    bary = barycentric_coordinates(pxy, v0[:2], v1[:2], v2[:2])
                    if perspective_correct:
                        z0, z1, z2 = v0[2], v1[2], v2[2]
                        l0, l1, l2 = bary[0], bary[1], bary[2]
                        top0 = l0 * z1 * z2
                        top1 = z0 * l1 * z2
                        top2 = z0 * z1 * l2
                        bot = top0 + top1 + top2
                        bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
                    pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]

                    # Check if point is behind the image.
                    if pz < 0:
                        continue

                    # Calculate signed 2D (squared) distance from point to face.
                    # Points inside the triangle have negative distance.
                    dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
                    inside = all(x > 0.0 for x in bary)

                    signed_dist = dist * -1.0 if inside else dist

                    # Outside the face, only keep pixels whose squared distance
                    # to it is strictly below blur_radius.
                    if not inside and dist >= blur_radius:
                        continue

                    # Keep the K closest hits, ordered by depth then face index.
                    top_k_points.append((pz, f, bary, signed_dist))
                    top_k_points.sort()
                    if len(top_k_points) > K:
                        top_k_points = top_k_points[:K]

                # Save to output tensors.
                for k, (pz, f, bary, dist) in enumerate(top_k_points):
                    zbuf[n, yi, xi, k] = pz
                    face_idxs[n, yi, xi, k] = f
                    bary_coords[n, yi, xi, k, 0] = bary[0]
                    bary_coords[n, yi, xi, k, 1] = bary[1]
                    bary_coords[n, yi, xi, k, 2] = bary[2]
                    pix_dists[n, yi, xi, k] = dist

    return face_idxs, zbuf, bary_coords, pix_dists


def edge_function(p, v0, v1):
    r"""
    Signed side test of a 2D point against the directed edge v0 -> v1.

    Args:
        p: (x, y) Coordinates of a point.
        v0, v1: (x, y) Coordinates of the end points of the edge.

    Returns:
        area: The signed area of the parallelogram spanned by

              .. code-block:: python

                  B = p - v0
                  A = v1 - v0

              evaluated as ``B.x * A.y - B.y * A.x`` (the 2D cross product
              B x A). A positive value means p is on the right side of the
              edge and a negative value means it is on the left side; zero
              means p is collinear with the edge.

              .. code-block:: python

                             v1
                            /
                           /
                    -     /    +
                         /
                        /
                      v0
    """
    # Components of B = p - v0.
    bx = p[0] - v0[0]
    by = p[1] - v0[1]
    # Components of A = v1 - v0.
    ax = v1[0] - v0[0]
    ay = v1[1] - v0[1]
    return bx * ay - by * ax


def barycentric_coordinates(p, v0, v1, v2):
    """
    Compute the barycentric coordinates of point ``p`` relative to the
    triangle (v0, v1, v2).

    Args:
        p: Coordinates of a point.
        v0, v1, v2: Coordinates of the triangle vertices.

    Returns:
        bary: Tuple (w0, w1, w2). Each weight is the signed area of the
        sub-triangle opposite that vertex divided by the (epsilon-padded)
        triangle area. For p inside the triangle the weights lie in [0, 1];
        outside, at least one weight is negative.
    """
    # Twice the signed triangle area; kEpsilon is added presumably to avoid
    # dividing by exactly zero for degenerate triangles.
    twice_area = edge_function(v2, v0, v1) + kEpsilon
    opposite_edges = ((v1, v2), (v2, v0), (v0, v1))
    return tuple(edge_function(p, a, b) / twice_area for a, b in opposite_edges)


def point_line_distance(p, v0, v1):
    """
    Return the squared minimum distance between line segment (v0, v1) and
    point p.

    Args:
        p: Coordinates of a point.
        v0, v1: Coordinates of the end points of the line segment.

    Returns:
        Squared distance from p to the closest point on the segment.

    Raises:
        ValueError: if p, v0 and v1 do not all have the same shape.

    Consider the line extending the segment - this can be parameterized as
    ``v0 + t (v1 - v0)``.

    First find the projection of point p onto the line. It falls where
    ``t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2``
    where . is the dot product.

    The parameter t is clamped to [0, 1] to handle points outside the
    segment (v1 - v0).

    Once the projection of the point on the segment is known, the distance from
    p to the projection gives the minimum distance to the segment.
    """
    # A chained `a != b != c` only means `a != b and b != c`, which lets
    # mismatches like p.shape == v0.shape != v1.shape through. Require all
    # three shapes to be equal.
    if not (p.shape == v0.shape == v1.shape):
        raise ValueError("All points must have the same number of coordinates")

    v1v0 = v1 - v0
    l2 = v1v0.dot(v1v0)  # |v1 - v0|^2
    if l2 <= kEpsilon:
        # Degenerate segment (v0 == v1): distance to the single point.
        return (p - v1).dot(p - v1)

    t = v1v0.dot(p - v0) / l2
    t = torch.clamp(t, min=0.0, max=1.0)
    p_proj = v0 + t * v1v0
    delta_p = p_proj - p
    return delta_p.dot(delta_p)


def point_triangle_distance(p, v0, v1, v2):
    """
    Return the smallest squared distance between a point and the edges of a
    triangle.

    Args:
        p: Coordinates of a point.
        v0, v1, v2: Coordinates of the three triangle vertices.

    Returns:
        Minimum over the three edges of the squared point-to-segment distance.
        This is the distance to the triangle's boundary: a point inside the
        triangle gets the distance to its nearest edge, not 0.

    Raises:
        ValueError: if p, v0, v1 and v2 do not all have the same shape.
    """
    # A chained `!=` comparison only checks adjacent pairs (and requires all
    # of them to differ), so it misses most mismatches; require equality.
    if not (p.shape == v0.shape == v1.shape == v2.shape):
        raise ValueError("All points must have the same number of coordinates")

    e01_dist = point_line_distance(p, v0, v1)
    e02_dist = point_line_distance(p, v0, v2)
    e12_dist = point_line_distance(p, v1, v2)
    edge_dists_min = torch.min(torch.min(e01_dist, e02_dist), e12_dist)

    return edge_dists_min