# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import torch
import torch.nn.functional as F
from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
from .common_testing import TestCaseMixin


class TestVertAlign(TestCaseMixin, unittest.TestCase):
    """
    Tests for pytorch3d.ops.vert_align against a naive per-example
    reference implementation built on F.grid_sample.
    """

    @staticmethod
    def vert_align_naive(
        feats, verts, return_packed: bool = False, align_corners: bool = True
    ):
        """
        Naive implementation of vert_align.

        Samples every feature map at the (x, y) coordinates of each vertex,
        one batch element at a time, using F.grid_sample.

        Args:
            feats: A tensor of shape (N, C, H, W), or a list of such tensors.
                The channel counts may differ between maps. The batch size is
                taken from feats[0].
            verts: One of:
                - a tensor whose entry [i] is a (V, >=2) tensor;
                - an object with a ``verts_list()`` method (e.g. Meshes);
                - an object with a ``points_list()`` method (e.g. Pointclouds).
                Only the first two coordinates are used, as grid_sample (x, y)
                sampling locations.
            return_packed: If True, concatenate the per-example results along
                the vertex dimension, giving (sum(V), sum(C)). Otherwise stack
                them, giving (N, V, sum(C)). Stacking requires every example
                to have the same number of vertices.
            align_corners: Forwarded unchanged to F.grid_sample.

        Returns:
            The sampled features, shaped as described under ``return_packed``.

        Raises:
            ValueError: If ``verts`` is none of the supported inputs.
        """
        if torch.is_tensor(feats):
            feats = [feats]
        N = feats[0].shape[0]

        # Resolve the per-example vertex source once. Previously
        # verts_list()/points_list() was called for every
        # (example, feature map) pair.
        if torch.is_tensor(verts):
            verts_per_example = verts
        elif hasattr(verts, "verts_list"):
            verts_per_example = verts.verts_list()
        elif hasattr(verts, "points_list"):
            verts_per_example = verts.points_list()
        else:
            raise ValueError("verts_or_meshes is invalid")

        out_feats = []
        # Sample every example in the batch separately.
        for i in range(N):
            # The xy coordinates of example i, laid out as a grid_sample grid.
            # The grid is shared by all feature maps of this example.
            grid = verts_per_example[i][None, None, :, :2]  # (1, 1, V, 2)
            out_i_feats = []
            for feat in feats:
                feats_i = feat[i][None, :, :, :]  # (1, C, H, W)
                feat_sampled_i = F.grid_sample(
                    feats_i,
                    grid,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=align_corners,
                )  # (1, C, 1, V)
                feat_sampled_i = feat_sampled_i.squeeze(2).squeeze(0)  # (C, V)
                feat_sampled_i = feat_sampled_i.transpose(1, 0)  # (V, C)
                out_i_feats.append(feat_sampled_i)
            out_i_feats = torch.cat(out_i_feats, 1)  # (V, sum(C))
            out_feats.append(out_i_feats)

        if return_packed:
            out_feats = torch.cat(out_feats, 0)  # (sum(V), sum(C))
        else:
            out_feats = torch.stack(out_feats, 0)  # (N, V, sum(C))
        return out_feats

    @staticmethod
    def init_meshes(
        num_meshes: int = 10,
        num_verts: int = 1000,
        num_faces: int = 3000,
        device: str = "cuda:0",
    ) -> Meshes:
        """
        Build a batch of random meshes.

        Args:
            num_meshes: Number of meshes in the batch.
            num_verts: Number of vertices per mesh.
            num_faces: Number of faces per mesh. Face indices are drawn
                uniformly from [0, num_verts).
            device: Device on which to allocate the tensors. Defaults to
                "cuda:0", the previously hard-coded device.

        Returns:
            A Meshes object whose vertex coordinates lie in [-1, 1).
        """
        device = torch.device(device)
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            # Map uniform [0, 1) samples onto [-1, 1), the normalized
            # coordinate range that F.grid_sample samples over.
            verts = (
                torch.rand((num_verts, 3), dtype=torch.float32, device=device) * 2.0
                - 1.0
            )
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        return Meshes(verts_list, faces_list)

    @staticmethod
    def init_pointclouds(
        num_clouds: int = 10, num_points: int = 1000, device: str = "cuda:0"
    ) -> Pointclouds:
        """
        Build a batch of random point clouds.

        Args:
            num_clouds: Number of point clouds in the batch.
            num_points: Number of points per cloud.
            device: Device on which to allocate the tensors. Defaults to
                "cuda:0", the previously hard-coded device.

        Returns:
            A Pointclouds object whose point coordinates lie in [-1, 1).
        """
        device = torch.device(device)
        points_list = []
        for _ in range(num_clouds):
            # Map uniform [0, 1) samples onto [-1, 1).
            points = (
                torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0
                - 1.0
            )
            points_list.append(points)
        return Pointclouds(points=points_list)

    @staticmethod
    def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
        """
        Build random feature maps at two spatial resolutions.

        Args:
            batch_size: Leading (batch) dimension of every map.
            num_channels: Channel dimension of every map.
            device: Device on which to allocate the tensors.

        Returns:
            A list of two tensors with shapes (batch_size, num_channels, 14, 14)
            and (batch_size, num_channels, 28, 28), uniform in [0, 1).
        """
        spatial_sizes = [(14, 14), (28, 28)]
        return [
            torch.rand((batch_size, num_channels, h, w), device=device)
            for h, w in spatial_sizes
        ]

    def test_vert_align_with_meshes(self):
        """
        Test vert align vs naive implementation with meshes.
        """
        meshes = TestVertAlign.init_meshes(10, 1000, 3000)
        feats = TestVertAlign.init_feats(10, 256)

        # feats in list
        out = vert_align(feats, meshes, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(feats, meshes, return_packed=True)
        self.assertClose(out, naive_out)

        # feats as tensor
        out = vert_align(feats[0], meshes, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
        self.assertClose(out, naive_out)

    def test_vert_align_with_pointclouds(self):
        """
        Test vert align vs naive implementation with pointclouds.
        """
        pointclouds = TestVertAlign.init_pointclouds(10, 1000)
        feats = TestVertAlign.init_feats(10, 256)

        # feats in list
        out = vert_align(feats, pointclouds, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats, pointclouds, return_packed=True
        )
        self.assertClose(out, naive_out)

        # feats as tensor
        out = vert_align(feats[0], pointclouds, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats[0], pointclouds, return_packed=True
        )
        self.assertClose(out, naive_out)

    def test_vert_align_with_verts(self):
        """
        Test vert align vs naive implementation with verts as tensor.
        """
        feats = TestVertAlign.init_feats(10, 256)
        verts = (
            torch.rand((10, 100, 3), dtype=torch.float32, device=feats[0].device) * 2.0
            - 1.0
        )

        # feats in list
        out = vert_align(feats, verts, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(feats, verts, return_packed=True)
        self.assertClose(out, naive_out)

        # feats as tensor
        out = vert_align(feats[0], verts, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(feats[0], verts, return_packed=True)
        self.assertClose(out, naive_out)

        # align_corners=False must change the result relative to the default
        # (align_corners=True), yet still agree with the naive implementation.
        out2 = vert_align(feats[0], verts, return_packed=True, align_corners=False)
        naive_out2 = TestVertAlign.vert_align_naive(
            feats[0], verts, return_packed=True, align_corners=False
        )
        self.assertFalse(torch.allclose(out, out2))
        self.assertClose(out2, naive_out2)

    @staticmethod
    def vert_align_with_init(
        num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
    ):
        """
        Prepare random meshes and feature maps, then return a closure that
        runs vert_align on them once (intended for benchmarking).

        Args:
            num_meshes: Number of meshes, which is also the feature batch size.
            num_verts: Number of vertices per mesh, uniform in [0, 1).
            num_faces: Number of faces per mesh.
            device: Device on which to allocate all inputs.

        Returns:
            A zero-argument callable that runs vert_align with
            return_packed=True. When the inputs are on a CUDA device, it waits
            for that device to finish.
        """
        device = torch.device(device)
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts_list, faces_list)
        feats = TestVertAlign.init_feats(num_meshes, device=device)

        # torch.cuda.synchronize() raises when CUDA is unavailable, and the
        # default device here is "cpu". Only synchronize (the right device)
        # when the inputs actually live on CUDA.
        is_cuda = device.type == "cuda"
        if is_cuda:
            torch.cuda.synchronize(device)

        def sample_features():
            vert_align(feats, meshes, return_packed=True)
            if is_cuda:
                torch.cuda.synchronize(device)

        return sample_features