test_gatedgcnconv.py 715 Bytes
Newer Older
paoxiaode's avatar
paoxiaode committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import io

import backend as F

import dgl.nn.pytorch as nn
import pytest
from utils import parametrize_idtype
from utils.graph_cases import get_cases

# Module-level, initially empty in-memory bytes buffer.
# NOTE(review): not referenced by the test in this file — confirm whether it
# is still needed.
tmp_buffer = io.BytesIO(b"")


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gatedgcn_conv(g, idtype):
    """Smoke-test ``GatedGCNConv`` on homogeneous test graphs.

    A ``GatedGCNConv(10, 10, 5)`` layer is run once on random node and edge
    features (both 10 wide), and the returned node and edge tensors are
    checked for the expected output shape. Numerical values are not checked.

    Args:
        g: A homogeneous test graph from ``get_cases`` (cases tagged
            "zero-degree" are excluded).
        idtype: Index dtype the graph is cast to, supplied by
            ``parametrize_idtype``.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gatedgcnconv = nn.GatedGCNConv(10, 10, 5).to(ctx)
    feat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 10))

    h, edge_h = gatedgcnconv(g, feat, efeat)
    # Currently we only do shape checks. Query sizes with the same num_*
    # accessors used to build the inputs above so the checks stay consistent.
    assert h.shape == (g.num_dst_nodes(), 5)
    assert edge_h.shape == (g.num_edges(), 5)