"vscode:/vscode.git/clone" did not exist on "ac02b1b19009ec161bccaf4763ce4030ec94b2be"
test_vert_align.py 5.79 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import unittest
import torch
import torch.nn.functional as F

from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes


class TestVertAlign(unittest.TestCase):
    @staticmethod
    def vert_align_naive(
        feats,
        verts_or_meshes,
        return_packed: bool = False,
        align_corners: bool = True,
    ):
        """
        Naive implementation of vert_align.
        """
        if torch.is_tensor(feats):
            feats = [feats]
        N = feats[0].shape[0]

        out_feats = []
        # sample every example in the batch separately
        for i in range(N):
            out_i_feats = []
            for feat in feats:
                feats_i = feat[i][None, :, :, :]  # (1, C, H, W)
                if torch.is_tensor(verts_or_meshes):
                    grid = verts_or_meshes[i][None, None, :, :2]  # (1, 1, V, 2)
                elif hasattr(verts_or_meshes, "verts_list"):
                    grid = verts_or_meshes.verts_list()[i][
                        None, None, :, :2
                    ]  # (1, 1, V, 2)
                else:
                    raise ValueError("verts_or_meshes is invalid")
                feat_sampled_i = F.grid_sample(
                    feats_i,
                    grid,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=align_corners,
                )  # (1, C, 1, V)
                feat_sampled_i = feat_sampled_i.squeeze(2).squeeze(0)  # (C, V)
                feat_sampled_i = feat_sampled_i.transpose(1, 0)  # (V, C)
                out_i_feats.append(feat_sampled_i)
            out_i_feats = torch.cat(out_i_feats, 1)  # (V, sum(C))
            out_feats.append(out_i_feats)

        if return_packed:
            out_feats = torch.cat(out_feats, 0)  # (sum(V), sum(C))
        else:
            out_feats = torch.stack(out_feats, 0)  # (N, V, sum(C))
        return out_feats

    @staticmethod
    def init_meshes(
        num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
    ):
        device = torch.device("cuda:0")
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            verts = (
                torch.rand((num_verts, 3), dtype=torch.float32, device=device)
                * 2.0
                - 1.0
            )  # verts in the space of [-1, 1]
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts_list, faces_list)

        return meshes

    @staticmethod
    def init_feats(
        batch_size: int = 10, num_channels: int = 256, device: str = "cuda"
    ):
        H, W = [14, 28], [14, 28]
        feats = []
        for (h, w) in zip(H, W):
            feats.append(
                torch.rand((batch_size, num_channels, h, w), device=device)
            )
        return feats

    def test_vert_align_with_meshes(self):
        """
        Test vert align vs naive implementation with meshes.
        """
        meshes = TestVertAlign.init_meshes(10, 1000, 3000)
        feats = TestVertAlign.init_feats(10, 256)

        # feats in list
        out = vert_align(feats, meshes, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats, meshes, return_packed=True
        )
        self.assertTrue(torch.allclose(out, naive_out))

        # feats as tensor
        out = vert_align(feats[0], meshes, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats[0], meshes, return_packed=True
        )
        self.assertTrue(torch.allclose(out, naive_out))

    def test_vert_align_with_verts(self):
        """
        Test vert align vs naive implementation with verts as tensor.
        """
        feats = TestVertAlign.init_feats(10, 256)
        verts = (
            torch.rand(
                (10, 100, 3), dtype=torch.float32, device=feats[0].device
            )
            * 2.0
            - 1.0
        )

        # feats in list
        out = vert_align(feats, verts, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats, verts, return_packed=True
        )
        self.assertTrue(torch.allclose(out, naive_out))

        # feats as tensor
        out = vert_align(feats[0], verts, return_packed=True)
        naive_out = TestVertAlign.vert_align_naive(
            feats[0], verts, return_packed=True
        )
        self.assertTrue(torch.allclose(out, naive_out))

        out2 = vert_align(
            feats[0], verts, return_packed=True, align_corners=False
        )
        naive_out2 = TestVertAlign.vert_align_naive(
            feats[0], verts, return_packed=True, align_corners=False
        )
        self.assertFalse(torch.allclose(out, out2))
        self.assertTrue(torch.allclose(out2, naive_out2))

    @staticmethod
    def vert_align_with_init(
        num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
    ):
        device = torch.device(device)
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            verts = torch.rand(
                (num_verts, 3), dtype=torch.float32, device=device
            )
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts_list, faces_list)
        feats = TestVertAlign.init_feats(num_meshes, device=device)
        torch.cuda.synchronize()

        def sample_features():
            vert_align(feats, meshes, return_packed=True)
            torch.cuda.synchronize()

        return sample_features