test_cugraph_gatconv.py
# pylint: disable=too-many-arguments, too-many-locals
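"""Equality test for ``dgl.nn.CuGraphGATConv`` against the reference
``dgl.nn.GATConv``: with synchronized parameters, the forward outputs and
parameter gradients of the two modules should agree within tolerance."""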
from collections import OrderedDict
from itertools import product

import dgl
import pytest
import torch
from dgl.nn import CuGraphGATConv, GATConv

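# Test option grid; the parametrized test below runs once for every
# combination produced by itertools.product over these values.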
options = OrderedDict(
    {
        "idtype_int": [False, True],
        "max_in_degree": [None, 8],
        "num_heads": [1, 3],
        "to_block": [False, True],
    }
)


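# Builds a small directed graph centered on nodes 0 and 9. Nodes 3 and 7 have
# no incoming edges, which is why GATConv below is constructed with
# allow_zero_in_degree=True.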
def generate_graph():
    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
    g = dgl.graph((u, v))
    return g


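# Skipped unconditionally in this snapshot; running it requires a CUDA device
# and a DGL build where CuGraphGATConv is available (pylibcugraphops).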
@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block):
    device = "cuda:0"
    in_feat, out_feat = 10, 2
    args = (in_feat, out_feat, num_heads)
    kwargs = {"bias": False}
    g = generate_graph().to(device)
    if idtype_int:
        g = g.int()
    if to_block:
        g = dgl.to_block(g)
    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)

    torch.manual_seed(0)
    conv1 = GATConv(*args, **kwargs, allow_zero_in_degree=True).to(device)
    out1 = conv1(g, feat)

    torch.manual_seed(0)
    conv2 = CuGraphGATConv(*args, **kwargs).to(device)
    dim = num_heads * out_feat
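    # CuGraphGATConv stores both attention vectors in a single flat tensor:
    # the first num_heads * out_feat entries play the role of GATConv's
    # attn_l, the remainder of attn_r. Copy conv1's parameters so the two
    # modules compute with identical weights.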
    with torch.no_grad():
        conv2.attn_weights.data[:dim] = conv1.attn_l.data.flatten()
        conv2.attn_weights.data[dim:] = conv1.attn_r.data.flatten()
        conv2.fc.weight.data[:] = conv1.fc.weight.data
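    # Forward through the cuGraph-ops path; max_in_degree is either None or
    # an explicit cap of 8, per the option grid above.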
    out2 = conv2(g, feat, max_in_degree=max_in_degree)
    assert torch.allclose(out1, out2, atol=1e-6)

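    # Backpropagate the same upstream gradient through both modules and check
    # that the parameter gradients also agree.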
    grad_out1 = torch.rand_like(out1)
    grad_out2 = grad_out1.clone().detach()
    out1.backward(grad_out1)
    out2.backward(grad_out2)

    assert torch.allclose(conv1.fc.weight.grad, conv2.fc.weight.grad, atol=1e-6)
    assert torch.allclose(
        torch.cat((conv1.attn_l.grad, conv1.attn_r.grad), dim=0),
        conv2.attn_weights.grad.view(2, num_heads, out_feat),
        atol=1e-6,
    )