test_graph_conv.py 7.11 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
4

facebook-github-bot's avatar
facebook-github-bot committed
5
6
import torch
import torch.nn as nn
Nikhila Ravi's avatar
Nikhila Ravi committed
7
from common_testing import TestCaseMixin, get_random_cuda_device
facebook-github-bot's avatar
facebook-github-bot committed
8
from pytorch3d import _C
9
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
facebook-github-bot's avatar
facebook-github-bot committed
10
11
12
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import ico_sphere

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
13
14

class TestGraphConv(TestCaseMixin, unittest.TestCase):
    """Tests for ``GraphConv`` and the ``gather_scatter`` kernels it uses.

    The hand-built graphs below use three vertices with features
    ``[1, 2, 3]``, ``[4, 5, 6]``, ``[7, 8, 9]`` and edges ``(0, 1)`` and
    ``(0, 2)``, so expected outputs can be written out term by term.
    """

    @staticmethod
    def _set_weights(conv, w0, w1=None):
        """Load fixed weights into ``conv`` and zero the matching biases.

        Args:
            conv: A ``GraphConv`` instance.
            w0: Weight copied into ``conv.w0.weight``.
            w1: Optional weight copied into ``conv.w1.weight``; when ``None``
                ``conv.w1`` is left untouched.
        """
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        if w1 is not None:
            conv.w1.weight.data.copy_(w1)
            conv.w1.bias.data.zero_()

    def test_undirected(self):
        """Undirected conv: each edge contributes to both of its endpoints.

        With ``w0 = [1, 1, 1]``, ``w1 = [-1, -1, -1]`` and zero biases the
        expected output for vertex i is ``sum(x_i) - sum(x_j)`` over all
        neighbours j of i.
        """
        dtype = torch.float32
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
        )
        edges = torch.tensor([[0, 1], [0, 2]], device=device)
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype, device=device)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype, device=device)

        # Vertex 0 neighbours 1 and 2; vertices 1 and 2 each neighbour only 0.
        expected_y = torch.tensor(
            [
                [1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9],
                [4 + 5 + 6 - 1 - 2 - 3],
                [7 + 8 + 9 - 1 - 2 - 3],
            ],
            dtype=dtype,
            device=device,
        )

        conv = GraphConv(3, 1, directed=False).to(device)
        self._set_weights(conv, w0, w1)

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_no_edges(self):
        """With no edges the output reduces to the self term ``w0 . x_i``."""
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.zeros(0, 2, dtype=torch.int64)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
        expected_y = torch.tensor(
            [[1 - 2 - 2 * 3], [4 - 5 - 2 * 6], [7 - 8 - 2 * 9]], dtype=dtype
        )
        conv = GraphConv(3, 1).to(dtype)
        self._set_weights(conv, w0)

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_no_verts_and_edges(self):
        """Empty inputs give an empty ``(0, output_dim)`` result.

        The result must still require grad so it can sit inside a larger
        autograd graph.
        """
        dtype = torch.float32
        verts = torch.tensor([], dtype=dtype, requires_grad=True)
        # NOTE(review): edges are deliberately an empty *float* tensor here;
        # the empty path is expected to bypass any index-dtype handling.
        edges = torch.tensor([], dtype=dtype)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)

        conv = GraphConv(3, 1).to(dtype)
        self._set_weights(conv, w0)
        y = conv(verts, edges)
        self.assertClose(y, torch.zeros((0, 1)))
        self.assertTrue(y.requires_grad)

        conv2 = GraphConv(3, 2).to(dtype)
        self._set_weights(conv2, w0.repeat(2, 1))
        y = conv2(verts, edges)
        self.assertClose(y, torch.zeros((0, 2)))
        self.assertTrue(y.requires_grad)

    def test_directed(self):
        """Directed conv: for an edge ``(a, b)`` only vertex ``a`` gathers.

        Vertex 0 receives the (negated) features of 1 and 2. Vertices 1 and 2
        receive nothing and keep only their self term.
        """
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.tensor([[0, 1], [0, 2]])
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype)

        expected_y = torch.tensor(
            [[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]], dtype=dtype
        )

        conv = GraphConv(3, 1, directed=True).to(dtype)
        self._set_weights(conv, w0, w1)

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_backward(self):
        """``gather_scatter`` gradients agree across CUDA, CPU and Python."""
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        verts_cpu = verts.clone()
        edges_cpu = edges.clone()
        verts_cuda = verts.clone().to(device)
        edges_cuda = edges.clone().to(device)
        verts.requires_grad = True
        verts_cpu.requires_grad = True
        verts_cuda.requires_grad = True

        neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
        neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
        neighbor_sums = gather_scatter_python(verts, edges, False)
        # A random upstream gradient (rather than all ones) makes each output
        # element contribute a distinct weight, so mismatches cannot cancel.
        randoms = torch.rand_like(neighbor_sums)
        (neighbor_sums_cuda * randoms.to(device)).sum().backward()
        (neighbor_sums_cpu * randoms).sum().backward()
        (neighbor_sums * randoms).sum().backward()

        self.assertClose(verts.grad, verts_cuda.grad.cpu())
        self.assertClose(verts.grad, verts_cpu.grad)

    def test_repr(self):
        conv = GraphConv(32, 64, directed=True)
        self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")

    def test_cpu_cuda_tensor_error(self):
        """Mixing CUDA verts with CPU edges must raise a device error."""
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
        )
        edges = torch.tensor([[0, 1], [0, 2]])
        conv = GraphConv(3, 1, directed=True).to(torch.float32)
        with self.assertRaises(Exception) as err:
            conv(verts, edges)
        # assertIn reports the actual message on failure, unlike assertTrue.
        self.assertIn("tensors must be on the same device.", str(err.exception))

    def test_gather_scatter(self):
        """
        Check that the CUDA and CPU ``_C.gather_scatter`` kernels give the same
        results as ``gather_scatter_python``, for both undirected and directed
        edges.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        features = w0(verts)

        # NOTE(review): the trailing ``False`` passed to ``_C.gather_scatter``
        # is presumably the kernel's "backward" flag; confirm against the
        # _C binding.

        # undirected
        output_python = gather_scatter_python(features, edges, False)
        output_cuda = _C.gather_scatter(
            features.to(device=device), edges.to(device=device), False, False
        )
        self.assertClose(output_cuda.cpu(), output_python)

        output_cpu = _C.gather_scatter(features.cpu(), edges.cpu(), False, False)
        self.assertClose(output_cpu, output_python)

        # directed
        output_python = gather_scatter_python(features, edges, True)
        output_cuda = _C.gather_scatter(
            features.to(device=device), edges.to(device=device), True, False
        )
        self.assertClose(output_cuda.cpu(), output_python)
        output_cpu = _C.gather_scatter(features.cpu(), edges.cpu(), True, False)
        self.assertClose(output_cpu, output_python)

    @staticmethod
    def graph_conv_forward_backward(
        gconv_dim,
        num_meshes,
        num_verts,
        num_faces,
        directed: bool,
        backend: str = "cuda",
    ):
        """Return a closure that runs one GraphConv forward + backward pass.

        Benchmark helper: all setup (mesh batch, layer, input features)
        happens here, so only the returned callable does the timed work.

        Args:
            gconv_dim: Input and output feature size of the ``GraphConv``.
            num_meshes: Number of identical meshes packed into the batch.
            num_verts: Vertices per mesh (all placed at the same point).
            num_faces: Faces per mesh. Every face is ``[1, 2, 3]``, so
                ``num_verts`` presumably needs to be at least 4.
            directed: Forwarded to ``GraphConv``.
            backend: ``"cuda"`` runs on the default CUDA device; any other
                value runs on the CPU.

        Returns:
            A zero-argument callable that runs forward and
            ``sum().backward()``, then synchronizes the device when running
            on CUDA.
        """
        device = torch.device("cuda" if backend == "cuda" else "cpu")
        # ``view(-1, 3)`` keeps an (N, 3) shape even when N == 0, where
        # ``torch.tensor([])`` would otherwise be 1-D.
        verts = torch.tensor(num_verts * [[0.11, 0.22, 0.33]], device=device).view(
            -1, 3
        )
        # Explicit int64 so an empty face list is not inferred as float.
        faces = torch.tensor(
            num_faces * [[1, 2, 3]], dtype=torch.int64, device=device
        ).view(-1, 3)
        meshes = Meshes(num_meshes * [verts], num_meshes * [faces])
        gconv = GraphConv(gconv_dim, gconv_dim, directed=directed)
        gconv.to(device)
        edges = meshes.edges_packed()
        total_verts = meshes.verts_packed().shape[0]

        # Features.
        x = torch.randn(total_verts, gconv_dim, device=device, requires_grad=True)

        def synchronize():
            # torch.cuda.synchronize() raises on machines without CUDA, so
            # only wait on the device when actually benchmarking CUDA.
            if device.type == "cuda":
                torch.cuda.synchronize()

        synchronize()

        def run_graph_conv():
            y1 = gconv(x, edges)
            y1.sum().backward()
            synchronize()

        return run_graph_conv