test_cugraph_sageconv.py
# pylint: disable=too-many-arguments, too-many-locals
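"""Parity test for CuGraphSAGEConv against DGL's reference SAGEConv.

Both layers are given identical weights, then forward outputs and parameter
gradients are compared. The test runs on CUDA (the device is hard-coded to
"cuda:0").
"""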
from collections import OrderedDict
from itertools import product

import dgl
import pytest
import torch
from dgl.nn import CuGraphSAGEConv, SAGEConv

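# Test options; the parametrize decorator below expands them into the
# Cartesian product of all option values.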
options = OrderedDict(
    {
        "idtype_int": [False, True],
        "max_in_degree": [None, 8],
        "to_block": [False, True],
    }
)


def generate_graph():
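    """Return a small directed graph with 10 nodes and 15 edges."""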
    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
    g = dgl.graph((u, v))
    return g


@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
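    """CuGraphSAGEConv should match SAGEConv in both forward and backward."""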
    device = "cuda:0"
    in_feat, out_feat = 5, 2
    kwargs = {"aggregator_type": "mean"}
    g = generate_graph().to(device)
    if idtype_int:
        g = g.int()
    if to_block:
        g = dgl.to_block(g)
    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)

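    # Re-seed before constructing each layer so both start from the same RNG
    # state; the explicit weight copy below makes the parameters identical.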
    torch.manual_seed(0)
    conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device)

    torch.manual_seed(0)
    conv2 = CuGraphSAGEConv(in_feat, out_feat, **kwargs).to(device)

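    # CuGraphSAGEConv fuses the two SAGEConv projections into one linear layer
    # with 2 * in_feat input columns: the first in_feat columns correspond to
    # fc_neigh (aggregated neighbor features), the rest to fc_self.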
    with torch.no_grad():
        conv2.linear.weight[:, :in_feat].copy_(conv1.fc_neigh.weight)
        conv2.linear.weight[:, in_feat:].copy_(conv1.fc_self.weight)
        conv2.linear.bias.copy_(conv1.fc_self.bias)

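    # Forward: conv2 additionally takes max_in_degree (a CuGraphSAGEConv-only
    # argument, parametrized above); outputs must agree within tolerance.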
    out1 = conv1(g, feat)
    out2 = conv2(g, feat, max_in_degree=max_in_degree)
    assert torch.allclose(out1, out2, atol=1e-6)

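    # Backward: push the same upstream gradient through both layers and
    # compare gradients of the corresponding parameter blocks.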
    grad_out = torch.rand_like(out1)
    out1.backward(grad_out)
    out2.backward(grad_out)
    assert torch.allclose(
        conv1.fc_neigh.weight.grad,
        conv2.linear.weight.grad[:, :in_feat],
        atol=1e-6,
    )
    assert torch.allclose(
        conv1.fc_self.weight.grad,
        conv2.linear.weight.grad[:, in_feat:],
        atol=1e-6,
    )
    assert torch.allclose(
        conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6
    )