"vscode:/vscode.git/clone" did not exist on "1598a579580553d7007a1cc7c6a507250fd52d70"
test_cugraph_relgraphconv.py 3.06 KB
Newer Older
1
2
3
4
5
6
import pytest
import torch
import dgl
from dgl.nn import CuGraphRelGraphConv
from dgl.nn import RelGraphConv

# TODO(tingyu66): Re-enable the following tests after updating cuGraph CI image.
use_longs = [False, True]
max_in_degrees = [None, 8]
regularizers = [None, "basis"]
device = "cuda"


def generate_graph():
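    """Return a small directed test graph with random edge types.

    Each edge is assigned one of three relation labels, stored in
    ``g.edata[dgl.ETYPE]``.
    """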
    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
    g = dgl.graph((u, v))
    num_rels = 3
    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),))
    return g


@pytest.mark.skip()
@pytest.mark.parametrize("use_long", use_longs)
@pytest.mark.parametrize("max_in_degree", max_in_degrees)
@pytest.mark.parametrize("regularizer", regularizers)
def test_full_graph(use_long, max_in_degree, regularizer):
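    """Compare CuGraphRelGraphConv against RelGraphConv on a full graph.

    Both modules are constructed from the same random seed so their
    parameters match; forward outputs and parameter gradients are then
    expected to agree.
    """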
    in_feat, out_feat, num_rels, num_bases = 10, 2, 3, 2
    kwargs = {
        "num_bases": num_bases,
        "regularizer": regularizer,
        "bias": False,
        "self_loop": False,
    }
    g = generate_graph().to(device)
    if use_long:
        g = g.long()
    else:
        g = g.int()
    feat = torch.ones(g.num_nodes(), in_feat).to(device)

    torch.manual_seed(0)
    conv1 = RelGraphConv(in_feat, out_feat, num_rels, **kwargs).to(device)

    torch.manual_seed(0)
    conv2 = CuGraphRelGraphConv(
        in_feat, out_feat, num_rels, max_in_degree=max_in_degree, **kwargs
    ).to(device)

    out1 = conv1(g, feat, g.edata[dgl.ETYPE])
    out2 = conv2(g, feat, g.edata[dgl.ETYPE])

    assert torch.allclose(out1, out2, atol=1e-06)

    grad_out = torch.rand_like(out1)
    out1.backward(grad_out)
    out2.backward(grad_out)
    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad, atol=1e-6)
    if regularizer is not None:
        assert torch.allclose(
            conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6
        )


@pytest.mark.skip()
@pytest.mark.parametrize("max_in_degree", max_in_degrees)
@pytest.mark.parametrize("regularizer", regularizers)
def test_mfg(max_in_degree, regularizer):
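    """Compare CuGraphRelGraphConv against RelGraphConv on a message flow
    graph (MFG), i.e. a block produced by ``dgl.to_block``.
    """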
    in_feat, out_feat, num_rels, num_bases = 10, 2, 3, 2
    kwargs = {
        "num_bases": num_bases,
        "regularizer": regularizer,
        "bias": False,
        "self_loop": False,
    }
    g = generate_graph().to(device)
    block = dgl.to_block(g)
    feat = torch.ones(g.num_nodes(), in_feat).to(device)

    torch.manual_seed(0)
    conv1 = RelGraphConv(in_feat, out_feat, num_rels, **kwargs).to(device)

    torch.manual_seed(0)
    conv2 = CuGraphRelGraphConv(
        in_feat, out_feat, num_rels, max_in_degree=max_in_degree, **kwargs
    ).to(device)

    out1 = conv1(block, feat[block.srcdata[dgl.NID]], block.edata[dgl.ETYPE])
    out2 = conv2(block, feat[block.srcdata[dgl.NID]], block.edata[dgl.ETYPE])

    assert torch.allclose(out1, out2, atol=1e-06)

    grad_out = torch.rand_like(out1)
    out1.backward(grad_out)
    out2.backward(grad_out)
    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad, atol=1e-6)
    if regularizer is not None:
        assert torch.allclose(
            conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6
        )