test_graph_conv.py 7.24 KB
Newer Older
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
facebook-github-bot's avatar
facebook-github-bot committed
6
7

import unittest
8

facebook-github-bot's avatar
facebook-github-bot committed
9
10
11
import torch
import torch.nn as nn
from pytorch3d import _C
12
from pytorch3d.ops.graph_conv import gather_scatter, gather_scatter_python, GraphConv
facebook-github-bot's avatar
facebook-github-bot committed
13
14
15
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import ico_sphere

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
16
17
from .common_testing import get_random_cuda_device, TestCaseMixin

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
18
19

class TestGraphConv(TestCaseMixin, unittest.TestCase):
facebook-github-bot's avatar
facebook-github-bot committed
20
21
    def test_undirected(self):
        """
        Check GraphConv on a tiny undirected graph with hand-set weights.

        Vertex 0 is connected to vertices 1 and 2. With w0 set to all ones,
        w1 set to all minus ones and both biases zeroed, the expected output
        for each vertex is the sum of its own features minus the sum of the
        features of every vertex it shares an edge with.
        """
        dtype = torch.float32
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
        )
        edges = torch.tensor([[0, 1], [0, 2]], device=device)
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype, device=device)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype, device=device)

        # Row i: own features (weight +1) minus features of all neighbors.
        expected_y = torch.tensor(
            [
                [1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9],
                [4 + 5 + 6 - 1 - 2 - 3],
                [7 + 8 + 9 - 1 - 2 - 3],
            ],
            dtype=dtype,
            device=device,
        )

        # Overwrite the randomly initialized parameters so the output is
        # fully determined by the values above.
        conv = GraphConv(3, 1, directed=False).to(device)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        conv.w1.weight.data.copy_(w1)
        conv.w1.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)
facebook-github-bot's avatar
facebook-github-bot committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61

    def test_no_edges(self):
        """
        Check GraphConv on a graph that has vertices but no edges.

        With an empty (0, 2) edge tensor and a zeroed w0 bias, the expected
        output for each vertex is just its features weighted by w0.
        """
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.zeros(0, 2, dtype=torch.int64)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
        # Row i is the dot product of w0 with vertex i's features.
        expected_y = torch.tensor(
            [[1 - 2 - 2 * 3], [4 - 5 - 2 * 6], [7 - 8 - 2 * 9]], dtype=dtype
        )
        conv = GraphConv(3, 1).to(dtype)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)
facebook-github-bot's avatar
facebook-github-bot committed
63
64
65
66
67
68

    def test_no_verts_and_edges(self):
        """
        Check GraphConv on a completely empty graph.

        For both one and two output features the result must be an empty
        tensor of shape (0, output_dim) that still requires grad, because the
        (empty) input verts tensor requires grad.
        """
        dtype = torch.float32
        verts = torch.tensor([], dtype=dtype, requires_grad=True)
        edges = torch.tensor([], dtype=dtype)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)

        for output_dim in (1, 2):
            with self.subTest(output_dim=output_dim):
                conv = GraphConv(3, output_dim).to(dtype)
                # One copy of the w0 row per output feature.
                conv.w0.weight.data.copy_(w0.repeat(output_dim, 1))
                conv.w0.bias.data.zero_()

                y = conv(verts, edges)
                self.assertClose(y, torch.zeros((0, output_dim)))
                self.assertTrue(y.requires_grad)

    def test_directed(self):
        """
        Check GraphConv on a tiny directed graph with hand-set weights.

        The edges are (0, 1) and (0, 2). With w0 set to all ones, w1 set to
        all minus ones and both biases zeroed, the expected output is:
        vertex 0 gets its own feature sum minus those of vertices 1 and 2,
        while vertices 1 and 2 get only their own feature sums.
        """
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.tensor([[0, 1], [0, 2]])
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype)

        # Only vertex 0 receives a neighbor contribution in the directed case.
        expected_y = torch.tensor(
            [[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]], dtype=dtype
        )

        # Overwrite the randomly initialized parameters so the output is
        # fully determined by the values above.
        conv = GraphConv(3, 1, directed=True).to(dtype)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        conv.w1.weight.data.copy_(w1)
        conv.w1.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)
facebook-github-bot's avatar
facebook-github-bot committed
103
104

    def test_backward(self):
        """
        Check that gradients w.r.t. the vertices agree across implementations.

        The same ico-sphere graph is pushed through gather_scatter on a CUDA
        device, gather_scatter on CPU tensors, and gather_scatter_python. Each
        output is weighted by the same random tensor and summed before calling
        backward, and the resulting vertex gradients are compared.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()

        # Independent leaf tensors so each implementation accumulates its
        # gradient separately.
        verts_cpu = verts.clone()
        edges_cpu = edges.clone()
        verts_cuda = verts.clone().to(device)
        edges_cuda = edges.clone().to(device)
        verts.requires_grad = True
        verts_cpu.requires_grad = True
        verts_cuda.requires_grad = True

        neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
        neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
        neighbor_sums = gather_scatter_python(verts, edges, False)

        # Shared random weights give every implementation the same
        # non-trivial upstream gradient.
        randoms = torch.rand_like(neighbor_sums)
        (neighbor_sums_cuda * randoms.to(device)).sum().backward()
        (neighbor_sums_cpu * randoms).sum().backward()
        (neighbor_sums * randoms).sum().backward()

        self.assertClose(verts.grad, verts_cuda.grad.cpu())
        self.assertClose(verts.grad, verts_cpu.grad)
facebook-github-bot's avatar
facebook-github-bot committed
127
128
129
130
131
132

    def test_repr(self):
        """Check the string representation of a directed 32 -> 64 GraphConv."""
        expected = "GraphConv(32 -> 64, directed=True)"
        layer = GraphConv(32, 64, directed=True)
        self.assertEqual(repr(layer), expected)

    def test_cpu_cuda_tensor_error(self):
        """
        Check that mixing a CUDA verts tensor with a CPU edges tensor raises
        an exception whose message mentions the device mismatch.
        """
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
        )
        # Deliberately left on the CPU to trigger the mismatch.
        edges = torch.tensor([[0, 1], [0, 2]])
        conv = GraphConv(3, 1, directed=True).to(torch.float32)
        with self.assertRaises(Exception) as err:
            conv(verts, edges)
        # assertIn reports both strings on failure, unlike assertTrue(a in b).
        self.assertIn("tensors must be on the same device.", str(err.exception))
facebook-github-bot's avatar
facebook-github-bot committed
142
143
144
145
146
147
148

    def test_gather_scatter(self):
        """
        Check that the compiled gather_scatter op gives the same results as
        the pure-Python reference, for both CUDA and CPU tensors and for both
        the undirected and the directed case.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        # Per-vertex features fed to every implementation under test.
        features = w0(verts)

        for directed in (False, True):
            with self.subTest(directed=directed):
                output_python = gather_scatter_python(features, edges, directed)

                # NOTE(review): the last positional argument is passed as
                # False throughout; presumably it selects the backward pass -
                # confirm against the _C.gather_scatter binding.
                output_cuda = _C.gather_scatter(
                    features.to(device=device), edges.to(device=device), directed, False
                )
                self.assertClose(output_cuda.cpu(), output_python)

                output_cpu = _C.gather_scatter(
                    features.cpu(), edges.cpu(), directed, False
                )
                self.assertClose(output_cpu, output_python)
facebook-github-bot's avatar
facebook-github-bot committed
174
175
176
177
178
179
180
181
182
183
184

    @staticmethod
    def graph_conv_forward_backward(
        gconv_dim,
        num_meshes,
        num_verts,
        num_faces,
        directed: bool,
        backend: str = "cuda",
    ):
        """
        Build a closure that runs one GraphConv forward + backward pass.

        Args:
            gconv_dim: number of input and output features of the GraphConv
                layer, and the feature size of the random input.
            num_meshes: number of identical meshes batched together.
            num_verts: number of vertices per mesh.
            num_faces: number of faces per mesh.
            directed: forwarded to GraphConv.
            backend: "cuda" to run on the GPU; any other value runs on the CPU.

        Returns:
            A zero-argument callable that applies the layer to the input
            features, sums the output and calls backward.
        """
        use_cuda = backend == "cuda"
        device = torch.device("cuda" if use_cuda else "cpu")

        def synchronize():
            # Only touch the CUDA runtime when it is actually in use, so the
            # CPU backend also works on hosts without CUDA.
            if use_cuda:
                torch.cuda.synchronize()

        verts_list = torch.tensor(num_verts * [[0.11, 0.22, 0.33]], device=device).view(
            -1, 3
        )
        faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(-1, 3)
        meshes = Meshes(num_meshes * [verts_list], num_meshes * [faces_list])
        gconv = GraphConv(gconv_dim, gconv_dim, directed=directed)
        gconv.to(device)
        edges = meshes.edges_packed()
        total_verts = meshes.verts_packed().shape[0]

        # Features.
        x = torch.randn(total_verts, gconv_dim, device=device, requires_grad=True)
        # Finish all setup work before the returned closure runs.
        synchronize()

        def run_graph_conv():
            y1 = gconv(x, edges)
            y1.sum().backward()
            # Wait for queued GPU work so the whole pass is included.
            synchronize()

        return run_graph_conv