Unverified Commit 9e630101 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix and lots of tests (#2650)

parent cf8a3fb3
...@@ -84,7 +84,7 @@ def bmm_maybe_select(A, B, index): ...@@ -84,7 +84,7 @@ def bmm_maybe_select(A, B, index):
return B[index, A, :] return B[index, A, :]
else: else:
BB = nd.take(B, index, axis=0) BB = nd.take(B, index, axis=0)
return nd.batch_dot(A.expand_dims(1), BB).squeeze() return nd.batch_dot(A.expand_dims(1), BB).squeeze(1)
def normalize(x, p=2, axis=1, eps=1e-12): def normalize(x, p=2, axis=1, eps=1e-12):
r"""Performs :math:`L_p` normalization of inputs over specified dimension. r"""Performs :math:`L_p` normalization of inputs over specified dimension.
......
...@@ -238,7 +238,7 @@ class RelGraphConv(nn.Module): ...@@ -238,7 +238,7 @@ class RelGraphConv(nn.Module):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device), etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device)) th.tensor(etypes, device=device))
weight = weight.index_select(0, etypes) weight = weight.index_select(0, etypes)
msg = th.bmm(h.unsqueeze(1), weight).squeeze() msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)
if 'norm' in edges.data: if 'norm' in edges.data:
msg = msg * edges.data['norm'] msg = msg * edges.data['norm']
......
...@@ -87,7 +87,7 @@ def bmm_maybe_select(A, B, index): ...@@ -87,7 +87,7 @@ def bmm_maybe_select(A, B, index):
return tf.gather(B, flatidx) return tf.gather(B, flatidx)
else: else:
BB = tf.gather(B, index) BB = tf.gather(B, index)
return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB)) return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB), 1)
class Identity(layers.Layer): class Identity(layers.Layer):
......
...@@ -20,13 +20,14 @@ def _AXWb(A, X, W, b): ...@@ -20,13 +20,14 @@ def _AXWb(A, X, W, b):
return Y + b.data(X.context) return Y + b.data(X.context)
@parametrize_dtype @parametrize_dtype
def test_graph_conv(idtype): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(idtype, out_dim):
g = dgl.from_networkx(nx.path_graph(3)) g = dgl.from_networkx(nx.path_graph(3))
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
adj = g.adjacency_matrix(transpose=False, ctx=ctx) adj = g.adjacency_matrix(transpose=False, ctx=ctx)
conv = nn.GraphConv(5, 2, norm='none', bias=True) conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
# test#1: basic # test#1: basic
h0 = F.ones((3, 5)) h0 = F.ones((3, 5))
...@@ -41,7 +42,7 @@ def test_graph_conv(idtype): ...@@ -41,7 +42,7 @@ def test_graph_conv(idtype):
assert len(g.edata) == 0 assert len(g.edata) == 0
check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias)) check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2) conv = nn.GraphConv(5, out_dim)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
# test#3: basic # test#3: basic
...@@ -55,7 +56,7 @@ def test_graph_conv(idtype): ...@@ -55,7 +56,7 @@ def test_graph_conv(idtype):
assert len(g.ndata) == 0 assert len(g.ndata) == 0
assert len(g.edata) == 0 assert len(g.edata) == 0
conv = nn.GraphConv(5, 2) conv = nn.GraphConv(5, out_dim)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
with autograd.train_mode(): with autograd.train_mode():
...@@ -83,38 +84,40 @@ def test_graph_conv(idtype): ...@@ -83,38 +84,40 @@ def test_graph_conv(idtype):
@pytest.mark.parametrize('norm', ['none', 'both', 'right']) @pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False]) @pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False]) @pytest.mark.parametrize('bias', [False])
def test_graph_conv2(idtype, g, norm, weight, bias): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias) conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
conv.initialize(ctx=F.ctx()) conv.initialize(ctx=F.ctx())
ext_w = F.randn((5, 2)).as_in_context(F.ctx()) ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())
nsrc = ndst = g.number_of_nodes() nsrc = ndst = g.number_of_nodes()
h = F.randn((nsrc, 5)).as_in_context(F.ctx()) h = F.randn((nsrc, 5)).as_in_context(F.ctx())
if weight: if weight:
h_out = conv(g, h) h_out = conv(g, h)
else: else:
h_out = conv(g, h, ext_w) h_out = conv(g, h, ext_w)
assert h_out.shape == (ndst, 2) assert h_out.shape == (ndst, out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right']) @pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False]) @pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False]) @pytest.mark.parametrize('bias', [False])
def test_graph_conv2_bi(idtype, g, norm, weight, bias): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias) conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
conv.initialize(ctx=F.ctx()) conv.initialize(ctx=F.ctx())
ext_w = F.randn((5, 2)).as_in_context(F.ctx()) ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())
nsrc = g.number_of_src_nodes() nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes() ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).as_in_context(F.ctx()) h = F.randn((nsrc, 5)).as_in_context(F.ctx())
h_dst = F.randn((ndst, 2)).as_in_context(F.ctx()) h_dst = F.randn((ndst, out_dim)).as_in_context(F.ctx())
if weight: if weight:
h_out = conv(g, (h, h_dst)) h_out = conv(g, (h, h_dst))
else: else:
h_out = conv(g, (h, h_dst), ext_w) h_out = conv(g, (h, h_dst), ext_w)
assert h_out.shape == (ndst, 2) assert h_out.shape == (ndst, out_dim)
def _S2AXWb(A, N, X, W, b): def _S2AXWb(A, N, X, W, b):
X1 = X * N X1 = X * N
...@@ -128,13 +131,14 @@ def _S2AXWb(A, N, X, W, b): ...@@ -128,13 +131,14 @@ def _S2AXWb(A, N, X, W, b):
return Y + b return Y + b
def test_tagconv(): @pytest.mark.parametrize('out_dim', [1, 2])
def test_tagconv(out_dim):
g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx()) g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
adj = g.adjacency_matrix(transpose=False, ctx=ctx) adj = g.adjacency_matrix(transpose=False, ctx=ctx)
norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5) norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)
conv = nn.TAGConv(5, 2, bias=True) conv = nn.TAGConv(5, out_dim, bias=True)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
print(conv) print(conv)
...@@ -148,86 +152,93 @@ def test_tagconv(): ...@@ -148,86 +152,93 @@ def test_tagconv():
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx))) assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx)))
conv = nn.TAGConv(5, 2) conv = nn.TAGConv(5, out_dim)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
# test#2: basic # test#2: basic
h0 = F.ones((3, 5)) h0 = F.ones((3, 5))
h1 = conv(g, h0) h1 = conv(g, h0)
assert h1.shape[-1] == 2 assert h1.shape[-1] == out_dim
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype): @pytest.mark.parametrize('out_dim', [1, 20])
@pytest.mark.parametrize('num_heads', [1, 5])
def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(10, 20, 5) # n_heads = 5 gat = nn.GATConv(10, out_dim, num_heads) # n_heads = 5
gat.initialize(ctx=ctx) gat.initialize(ctx=ctx)
print(gat) print(gat)
feat = F.randn((g.number_of_nodes(), 10)) feat = F.randn((g.number_of_nodes(), 10))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_nodes(), 5, 20) assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
_, a = gat(g, feat, True) _, a = gat(g, feat, True)
assert a.shape == (g.number_of_edges(), 5, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(5, 2, 4) gat = nn.GATConv(5, out_dim, num_heads)
gat.initialize(ctx=ctx) gat.initialize(ctx=ctx)
feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5))) feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_dst_nodes(), 4, 2) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, True) _, a = gat(g, feat, True)
assert a.shape == (g.number_of_edges(), 4, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv(idtype, g, aggre_type): @pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, out_dim, aggre_type)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 10 assert h.shape[-1] == out_dim
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi(idtype, g, aggre_type): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
dst_dim = 5 if aggre_type != 'gcn' else 10 dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type) sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim))) feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == g.number_of_dst_nodes() assert h.shape[0] == g.number_of_dst_nodes()
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi2(idtype, aggre_type): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi2(idtype, aggre_type, out_dim):
# Test the case for graphs without edges # Test the case for graphs without edges
g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3}) g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
sage = nn.SAGEConv((3, 3), 2, 'gcn') sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3))) feat = (F.randn((5, 3)), F.randn((3, 3)))
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == 3 assert h.shape[0] == 3
for aggre_type in ['mean', 'pool']: for aggre_type in ['mean', 'pool']:
sage = nn.SAGEConv((3, 1), 2, aggre_type) sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1))) feat = (F.randn((5, 3)), F.randn((3, 1)))
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == 3 assert h.shape[0] == 3
def test_gg_conv(): def test_gg_conv():
...@@ -244,18 +255,19 @@ def test_gg_conv(): ...@@ -244,18 +255,19 @@ def test_gg_conv():
h1 = gg_conv(g, h0, etypes) h1 = gg_conv(g, h0, etypes)
assert h1.shape == (20, 20) assert h1.shape == (20, 20)
def test_cheb_conv(): @pytest.mark.parametrize('out_dim', [1, 20])
def test_cheb_conv(out_dim):
g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx()) g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
cheb = nn.ChebConv(10, 20, 3) # k = 3 cheb = nn.ChebConv(10, out_dim, 3) # k = 3
cheb.initialize(ctx=ctx) cheb.initialize(ctx=ctx)
print(cheb) print(cheb)
# test#1: basic # test#1: basic
h0 = F.randn((20, 10)) h0 = F.randn((20, 10))
h1 = cheb(g, h0) h1 = cheb(g, h0)
assert h1.shape == (20, 20) assert h1.shape == (20, out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
...@@ -294,13 +306,14 @@ def test_appnp_conv(): ...@@ -294,13 +306,14 @@ def test_appnp_conv():
h1 = appnp_conv(g, h0) h1 = appnp_conv(g, h0)
assert h1.shape == (20, 10) assert h1.shape == (20, 10)
def test_dense_cheb_conv(): @pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
for k in range(1, 4): for k in range(1, 4):
ctx = F.ctx() ctx = F.ctx()
g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx()) g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())
adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default') adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
cheb = nn.ChebConv(5, 2, k) cheb = nn.ChebConv(5, out_dim, k)
dense_cheb = nn.DenseChebConv(5, 2, k) dense_cheb = nn.DenseChebConv(5, out_dim, k)
cheb.initialize(ctx=ctx) cheb.initialize(ctx=ctx)
dense_cheb.initialize(ctx=ctx) dense_cheb.initialize(ctx=ctx)
...@@ -319,12 +332,13 @@ def test_dense_cheb_conv(): ...@@ -319,12 +332,13 @@ def test_dense_cheb_conv():
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none']) @pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(idtype, g, norm_type): @pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_graph_conv(idtype, g, norm_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default') adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
conv = nn.GraphConv(5, 2, norm=norm_type, bias=True) conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True) dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
dense_conv.initialize(ctx=ctx) dense_conv.initialize(ctx=ctx)
dense_conv.weight.set_data( dense_conv.weight.set_data(
...@@ -338,12 +352,13 @@ def test_dense_graph_conv(idtype, g, norm_type): ...@@ -338,12 +352,13 @@ def test_dense_graph_conv(idtype, g, norm_type):
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite', 'block-bipartite'])) @pytest.mark.parametrize('g', get_cases(['homo', 'bipartite', 'block-bipartite']))
def test_dense_sage_conv(idtype, g): @pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_sage_conv(idtype, g, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default') adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
sage = nn.SAGEConv(5, 2, 'gcn') sage = nn.SAGEConv(5, out_dim, 'gcn')
dense_sage = nn.DenseSAGEConv(5, 2) dense_sage = nn.DenseSAGEConv(5, out_dim)
sage.initialize(ctx=ctx) sage.initialize(ctx=ctx)
dense_sage.initialize(ctx=ctx) dense_sage.initialize(ctx=ctx)
dense_sage.fc.weight.set_data( dense_sage.fc.weight.set_data(
...@@ -364,30 +379,32 @@ def test_dense_sage_conv(idtype, g): ...@@ -364,30 +379,32 @@ def test_dense_sage_conv(idtype, g):
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
edge_conv = nn.EdgeConv(5, 2) edge_conv = nn.EdgeConv(5, out_dim)
edge_conv.initialize(ctx=ctx) edge_conv.initialize(ctx=ctx)
print(edge_conv) print(edge_conv)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = edge_conv(g, h0) h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_nodes(), out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
edge_conv = nn.EdgeConv(5, 2) edge_conv = nn.EdgeConv(5, out_dim)
edge_conv.initialize(ctx=ctx) edge_conv.initialize(ctx=ctx)
print(edge_conv) print(edge_conv)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_src_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
x0 = F.randn((g.number_of_dst_nodes(), 5)) x0 = F.randn((g.number_of_dst_nodes(), 5))
h1 = edge_conv(g, (h0, x0)) h1 = edge_conv(g, (h0, x0))
assert h1.shape == (g.number_of_dst_nodes(), 2) assert h1.shape == (g.number_of_dst_nodes(), out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
...@@ -475,19 +492,20 @@ def test_nn_conv_bi(g, idtype): ...@@ -475,19 +492,20 @@ def test_nn_conv_bi(g, idtype):
h1 = nn_conv(g, (h0, hd), etypes) h1 = nn_conv(g, (h0, hd), etypes)
assert h1.shape == (g.number_of_dst_nodes(), 2) assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_sg_conv(): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sg_conv(out_dim):
g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx()) g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
g = dgl.add_self_loop(g) g = dgl.add_self_loop(g)
ctx = F.ctx() ctx = F.ctx()
sgc = nn.SGConv(5, 2, 2) sgc = nn.SGConv(5, out_dim, 2)
sgc.initialize(ctx=ctx) sgc.initialize(ctx=ctx)
print(sgc) print(sgc)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
h1 = sgc(g, h0) h1 = sgc(g, h0)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_nodes(), out_dim)
def test_set2set(): def test_set2set():
g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx()) g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
...@@ -577,7 +595,8 @@ def test_simple_pool(): ...@@ -577,7 +595,8 @@ def test_simple_pool():
h1 = sort_pool(bg, h0) h1 = sort_pool(bg, h0)
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2 assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def test_rgcn(): @pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
ctx = F.ctx() ctx = F.ctx()
etype = [] etype = []
g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx()) g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx())
...@@ -587,7 +606,6 @@ def test_rgcn(): ...@@ -587,7 +606,6 @@ def test_rgcn():
etype.append(i % 5) etype.append(i % 5)
B = 2 B = 2
I = 10 I = 10
O = 8
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
rgc_basis.initialize(ctx=ctx) rgc_basis.initialize(ctx=ctx)
...@@ -596,12 +614,13 @@ def test_rgcn(): ...@@ -596,12 +614,13 @@ def test_rgcn():
h_new = rgc_basis(g, h, r) h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O] assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) if O % B == 0:
rgc_bdd.initialize(ctx=ctx) rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
h = nd.random.randn(100, I, ctx=ctx) rgc_bdd.initialize(ctx=ctx)
r = nd.array(etype, ctx=ctx) h = nd.random.randn(100, I, ctx=ctx)
h_new = rgc_bdd(g, h, r) r = nd.array(etype, ctx=ctx)
assert list(h_new.shape) == [100, O] h_new = rgc_bdd(g, h, r)
assert list(h_new.shape) == [100, O]
# with norm # with norm
norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx) norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)
...@@ -613,12 +632,13 @@ def test_rgcn(): ...@@ -613,12 +632,13 @@ def test_rgcn():
h_new = rgc_basis(g, h, r, norm) h_new = rgc_basis(g, h, r, norm)
assert list(h_new.shape) == [100, O] assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) if O % B == 0:
rgc_bdd.initialize(ctx=ctx) rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
h = nd.random.randn(100, I, ctx=ctx) rgc_bdd.initialize(ctx=ctx)
r = nd.array(etype, ctx=ctx) h = nd.random.randn(100, I, ctx=ctx)
h_new = rgc_bdd(g, h, r, norm) r = nd.array(etype, ctx=ctx)
assert list(h_new.shape) == [100, O] h_new = rgc_bdd(g, h, r, norm)
assert list(h_new.shape) == [100, O]
# id input # id input
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
......
This diff is collapsed.
...@@ -18,12 +18,13 @@ def _AXWb(A, X, W, b): ...@@ -18,12 +18,13 @@ def _AXWb(A, X, W, b):
Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape) Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)
return Y + b return Y + b
def test_graph_conv(): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(out_dim):
g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx()) g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx))) adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))
conv = nn.GraphConv(5, 2, norm='none', bias=True) conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
# conv = conv # conv = conv
print(conv) print(conv)
# test#1: basic # test#1: basic
...@@ -39,7 +40,7 @@ def test_graph_conv(): ...@@ -39,7 +40,7 @@ def test_graph_conv():
assert len(g.edata) == 0 assert len(g.edata) == 0
assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias)) assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2) conv = nn.GraphConv(5, out_dim)
# conv = conv # conv = conv
# test#3: basic # test#3: basic
h0 = F.ones((3, 5)) h0 = F.ones((3, 5))
...@@ -52,7 +53,7 @@ def test_graph_conv(): ...@@ -52,7 +53,7 @@ def test_graph_conv():
assert len(g.ndata) == 0 assert len(g.ndata) == 0
assert len(g.edata) == 0 assert len(g.edata) == 0
conv = nn.GraphConv(5, 2) conv = nn.GraphConv(5, out_dim)
# conv = conv # conv = conv
# test#3: basic # test#3: basic
h0 = F.ones((3, 5)) h0 = F.ones((3, 5))
...@@ -76,38 +77,40 @@ def test_graph_conv(): ...@@ -76,38 +77,40 @@ def test_graph_conv():
@pytest.mark.parametrize('norm', ['none', 'both', 'right']) @pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False]) @pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False]) @pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(idtype, g, norm, weight, bias): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias) conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
ext_w = F.randn((5, 2)) ext_w = F.randn((5, out_dim))
nsrc = g.number_of_src_nodes() nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes() ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)) h = F.randn((nsrc, 5))
h_dst = F.randn((ndst, 2)) h_dst = F.randn((ndst, out_dim))
if weight: if weight:
h_out = conv(g, h) h_out = conv(g, h)
else: else:
h_out = conv(g, h, weight=ext_w) h_out = conv(g, h, weight=ext_w)
assert h_out.shape == (ndst, 2) assert h_out.shape == (ndst, out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right']) @pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False]) @pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False]) @pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2_bi(idtype, g, norm, weight, bias): @pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias) conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
ext_w = F.randn((5, 2)) ext_w = F.randn((5, out_dim))
nsrc = g.number_of_src_nodes() nsrc = g.number_of_src_nodes()
ndst = g.number_of_dst_nodes() ndst = g.number_of_dst_nodes()
h = F.randn((nsrc, 5)) h = F.randn((nsrc, 5))
h_dst = F.randn((ndst, 2)) h_dst = F.randn((ndst, out_dim))
if weight: if weight:
h_out = conv(g, (h, h_dst)) h_out = conv(g, (h, h_dst))
else: else:
h_out = conv(g, (h, h_dst), weight=ext_w) h_out = conv(g, (h, h_dst), weight=ext_w)
assert h_out.shape == (ndst, 2) assert h_out.shape == (ndst, out_dim)
def test_simple_pool(): def test_simple_pool():
ctx = F.ctx() ctx = F.ctx()
...@@ -179,7 +182,8 @@ def test_glob_att_pool(): ...@@ -179,7 +182,8 @@ def test_glob_att_pool():
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2 assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_rgcn(): @pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
etype = [] etype = []
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx()) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
# 5 etypes # 5 etypes
...@@ -188,7 +192,6 @@ def test_rgcn(): ...@@ -188,7 +192,6 @@ def test_rgcn():
etype.append(i % 5) etype.append(i % 5)
B = 2 B = 2
I = 10 I = 10
O = 8
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True) rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
...@@ -203,17 +206,18 @@ def test_rgcn(): ...@@ -203,17 +206,18 @@ def test_rgcn():
assert list(h_new_low.shape) == [100, O] assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low) assert F.allclose(h_new, h_new_low)
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) if O % B == 0:
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True) rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd_low.weight = rgc_bdd.weight rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight rgc_bdd_low.weight = rgc_bdd.weight
h = tf.random.normal((100, I)) rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
r = tf.constant(etype) h = tf.random.normal((100, I))
h_new = rgc_bdd(g, h, r) r = tf.constant(etype)
h_new_low = rgc_bdd_low(g, h, r) h_new = rgc_bdd(g, h, r)
assert list(h_new.shape) == [100, O] h_new_low = rgc_bdd_low(g, h, r)
assert list(h_new_low.shape) == [100, O] assert list(h_new.shape) == [100, O]
assert F.allclose(h_new, h_new_low) assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
# with norm # with norm
norm = tf.zeros((g.number_of_edges(), 1)) norm = tf.zeros((g.number_of_edges(), 1))
...@@ -231,17 +235,18 @@ def test_rgcn(): ...@@ -231,17 +235,18 @@ def test_rgcn():
assert list(h_new_low.shape) == [100, O] assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low) assert F.allclose(h_new, h_new_low)
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) if O % B == 0:
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True) rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd_low.weight = rgc_bdd.weight rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight rgc_bdd_low.weight = rgc_bdd.weight
h = tf.random.normal((100, I)) rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
r = tf.constant(etype) h = tf.random.normal((100, I))
h_new = rgc_bdd(g, h, r, norm) r = tf.constant(etype)
h_new_low = rgc_bdd_low(g, h, r, norm) h_new = rgc_bdd(g, h, r, norm)
assert list(h_new.shape) == [100, O] h_new_low = rgc_bdd_low(g, h, r, norm)
assert list(h_new_low.shape) == [100, O] assert list(h_new.shape) == [100, O]
assert F.allclose(h_new, h_new_low) assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
# id input # id input
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
...@@ -259,87 +264,94 @@ def test_rgcn(): ...@@ -259,87 +264,94 @@ def test_rgcn():
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(5, 2, 4) gat = nn.GATConv(5, out_dim, num_heads)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_nodes(), 4, 2) assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
_, a = gat(g, feat, get_attention=True) _, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
gat = nn.GATConv(5, 2, 4) gat = nn.GATConv(5, out_dim, num_heads)
feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5))) feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
h = gat(g, feat) h = gat(g, feat)
assert h.shape == (g.number_of_dst_nodes(), 4, 2) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, get_attention=True) _, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1) assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv(idtype, g, aggre_type): @pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, out_dim, aggre_type)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 10 assert h.shape[-1] == out_dim
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi(idtype, g, aggre_type): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
sage = nn.SAGEConv(5, 10, aggre_type)
dst_dim = 5 if aggre_type != 'gcn' else 10 dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type) sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim))) feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == g.number_of_dst_nodes() assert h.shape[0] == g.number_of_dst_nodes()
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi_empty(idtype, aggre_type): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
# Test the case for graphs without edges # Test the case for graphs without edges
g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3}).to(F.ctx()) g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3}).to(F.ctx())
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
sage = nn.SAGEConv((3, 3), 2, 'gcn') sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3))) feat = (F.randn((5, 3)), F.randn((3, 3)))
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == 3 assert h.shape[0] == 3
for aggre_type in ['mean', 'pool', 'lstm']: for aggre_type in ['mean', 'pool', 'lstm']:
sage = nn.SAGEConv((3, 1), 2, aggre_type) sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1))) feat = (F.randn((5, 3)), F.randn((3, 1)))
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 2 assert h.shape[-1] == out_dim
assert h.shape[0] == 3 assert h.shape[0] == 3
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgc_conv(g, idtype): @pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
ctx = F.ctx() ctx = F.ctx()
g = g.astype(idtype).to(ctx) g = g.astype(idtype).to(ctx)
# not cached # not cached
sgc = nn.SGConv(5, 10, 3) sgc = nn.SGConv(5, out_dim, 3)
feat = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
h = sgc(g, feat) h = sgc(g, feat)
assert h.shape[-1] == 10 assert h.shape[-1] == out_dim
# cached # cached
sgc = nn.SGConv(5, 10, 3, True) sgc = nn.SGConv(5, out_dim, 3, True)
h_0 = sgc(g, feat) h_0 = sgc(g, feat)
h_1 = sgc(g, feat + 1) h_1 = sgc(g, feat + 1)
assert F.allclose(h_0, h_1) assert F.allclose(h_0, h_1)
assert h_0.shape[-1] == 10 assert h_0.shape[-1] == out_dim
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
...@@ -463,21 +475,22 @@ def test_hetero_conv(agg, idtype): ...@@ -463,21 +475,22 @@ def test_hetero_conv(agg, idtype):
assert mod3.carg2 == 1 assert mod3.carg2 == 1
def test_dense_cheb_conv(): @pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
for k in range(3, 4): for k in range(3, 4):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1, random_state=42)) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1, random_state=42))
g = g.to(ctx) g = g.to(ctx)
adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx))) adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))
cheb = nn.ChebConv(5, 2, k, None, bias=True) cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
dense_cheb = nn.DenseChebConv(5, 2, k, bias=True) dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)
# init cheb modules # init cheb modules
feat = F.ones((100, 5)) feat = F.ones((100, 5))
out_cheb = cheb(g, feat, [2.0]) out_cheb = cheb(g, feat, [2.0])
dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, 2)) dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
if cheb.linear.bias is not None: if cheb.linear.bias is not None:
dense_cheb.bias = cheb.linear.bias dense_cheb.bias = cheb.linear.bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment