Unverified Commit 650f6ee1 authored by Zihao Ye, committed by GitHub

[NN] Add commonly used GNN models from examples to dgl.nn modules. (#748)

* gat

* upd

* upd sage

* upd

* upd

* upd

* upd

* upd

* add gmmconv

* upd ggnn

* upd

* upd

* upd

* upd

* add citation examples

* add README

* fix cheb

* improve doc

* formula

* upd

* trigger

* lint

* lint

* upd

* add test for transform

* add test

* check

* upd

* improve doc

* shape check

* upd

* densechebconv, currently not correct (?)

* fix cheb

* fix

* upd

* upd sgc-reddit

* upd

* trigger
parent 8079d986
"""MXNet modules for graph global pooling.""" """MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235 # pylint: disable= no-member, arguments-differ, invalid-name, W0235
from mxnet import gluon, nd from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
......
"""Torch modules for graph global pooling.""" """Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235 # pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
...@@ -178,17 +178,6 @@ class GlobalAttentionPooling(nn.Module): ...@@ -178,17 +178,6 @@ class GlobalAttentionPooling(nn.Module):
super(GlobalAttentionPooling, self).__init__() super(GlobalAttentionPooling, self).__init__()
self.gate_nn = gate_nn self.gate_nn = gate_nn
self.feat_nn = feat_nn self.feat_nn = feat_nn
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
for p in self.gate_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
if self.feat_nn:
for p in self.feat_nn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat, graph):
r"""Compute global attention pooling.
......
"""Module for graph transformation methods.""" """Module for graph transformation utilities."""
import numpy as np
from scipy import sparse
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .graph import DGLGraph from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph from .graph_index import from_coo
from .batched_graph import BatchedDGLGraph, unbatch
from .backend import asnumpy, tensor
__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected'] __all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max']
def line_graph(g, backtracking=True, shared=False): def line_graph(g, backtracking=True, shared=False):
...@@ -12,6 +19,7 @@ def line_graph(g, backtracking=True, shared=False): ...@@ -12,6 +19,7 @@ def line_graph(g, backtracking=True, shared=False):
Parameters Parameters
---------- ----------
g : dgl.DGLGraph g : dgl.DGLGraph
The input graph.
backtracking : bool, optional backtracking : bool, optional
Whether the returned line graph is backtracking. Whether the returned line graph is backtracking.
shared : bool, optional shared : bool, optional
...@@ -26,6 +34,88 @@ def line_graph(g, backtracking=True, shared=False): ...@@ -26,6 +34,88 @@ def line_graph(g, backtracking=True, shared=False):
node_frame = g._edge_frame if shared else None node_frame = g._edge_frame if shared else None
return DGLGraph(graph_data, node_frame) return DGLGraph(graph_data, node_frame)
def khop_adj(g, k):
"""Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of :math:`g`,
where a row represents the destination and a column represents the source.
Parameters
----------
g : dgl.DGLGraph
The input graph.
k : int
The :math:`k` in :math:`A^k`.
Returns
-------
tensor
The returned tensor, dtype is ``np.float32``.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0])
>>> dgl.khop_adj(g, 1)
tensor([[1., 0., 0., 0., 1.],
[1., 1., 0., 0., 0.],
[0., 1., 1., 0., 0.],
[0., 0., 1., 1., 0.],
[0., 0., 0., 1., 1.]])
>>> dgl.khop_adj(g, 3)
tensor([[1., 0., 1., 3., 3.],
[3., 1., 0., 1., 3.],
[3., 3., 1., 0., 1.],
[1., 3., 3., 1., 0.],
[0., 1., 3., 3., 1.]])
"""
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
return tensor(adj_k.todense().astype(np.float32))
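As a quick sanity check (a minimal sketch outside the library, assuming the PyTorch backend so the returned tensor supports ``@``), one multiplication by ``khop_adj(g, k)`` should match :math:`k` rounds of copy-and-sum message passing; the ``test_khop_adj`` test later in this diff verifies the same identity:

import dgl
import dgl.function as fn
import torch

# Small cycle-with-self-loops graph from the docstring example above.
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [0, 1, 2, 3, 4, 1, 2, 3, 4, 0])
feat = torch.randn(5, 4)

# One multiplication by A^2 (rows are destinations, columns are sources) ...
h_adj = dgl.khop_adj(g, 2) @ feat

# ... versus two rounds of copy/sum message passing on the original graph.
g.ndata['h'] = feat
for _ in range(2):
    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_mp = g.ndata.pop('h')

assert torch.allclose(h_adj, h_mp, atol=1e-5)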
def khop_graph(g, k):
"""Return the graph that includes all :math:`k`-hop neighbors of the given graph as edges.
The adjacency matrix of the returned graph is :math:`A^k`
(where :math:`A` is the adjacency matrix of :math:`g`).
Parameters
----------
g : dgl.DGLGraph
The input graph.
k : int
The :math:`k` in `k`-hop graph.
Returns
-------
dgl.DGLGraph
The returned ``DGLGraph``.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0])
>>> dgl.khop_graph(g, 1)
DGLGraph(num_nodes=5, num_edges=10,
ndata_schemes={}
edata_schemes={})
>>> dgl.khop_graph(g, 3)
DGLGraph(num_nodes=5, num_edges=40,
ndata_schemes={}
edata_schemes={})
"""
n = g.number_of_nodes()
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
adj_k = adj_k.tocoo()
multiplicity = adj_k.data
row = np.repeat(adj_k.row, multiplicity)
col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future.
return DGLGraph(from_coo(n, row, col, True, True))
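The ``np.repeat`` trick above turns walk counts into parallel edges: entry :math:`(i, j)` of :math:`A^k` counts the length-:math:`k` walks from :math:`j` to :math:`i`, and each such walk becomes one edge of the :math:`k`-hop multigraph. A minimal NumPy/SciPy sketch of the expansion on a toy matrix (not DGL code):

import numpy as np
from scipy import sparse

# Tiny example: self-loops on nodes 0 and 1 plus an edge 0 -> 1, with A[dst, src] = 1.
A = sparse.csr_matrix(np.array([[1, 0],
                                [1, 1]]))
A2 = (A ** 2).tocoo()  # entry (i, j) counts length-2 walks from j to i
# There are two length-2 walks from 0 to 1 (0->0->1 and 0->1->1), so A2[1, 0] == 2.
# Repeating each (row, col) pair by its count yields one parallel edge per walk,
# mirroring the np.repeat calls in khop_graph.
row = np.repeat(A2.row, A2.data)
col = np.repeat(A2.col, A2.data)
assert sorted(zip(row.tolist(), col.tolist())) == [(0, 0), (1, 0), (1, 0), (1, 1)]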
def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph
@@ -46,6 +136,7 @@ def reverse(g, share_ndata=False, share_edata=False):
Parameters
----------
g : dgl.DGLGraph
The input graph.
share_ndata: bool, optional
If True, the original graph and the reversed graph share memory for node attributes.
Otherwise the reversed graph will not be initialized with node attributes.
@@ -169,4 +260,49 @@ def to_bidirected(g, readonly=True):
newgidx = _CAPI_DGLToBidirectedMutableGraph(g._graph)
return DGLGraph(newgidx)
def laplacian_lambda_max(g):
"""Return the largest eigenvalue of the normalized symmetric laplacian of g.
The eigenvalue of the normalized symmetric of any graph is less than or equal to 2,
ref: https://en.wikipedia.org/wiki/Laplacian_matrix#Properties
Parameters
----------
g : DGLGraph or BatchedDGLGraph
The input graph, it should be an undirected graph.
Returns
-------
list :
* If the input g is a DGLGraph, the returned value would be
a list with one element, indicating the largest eigenvalue of g.
* If the input g is a BatchedDGLGraph, the returned value would
be a list, where the i-th item indicates the largest eigenvalue
of i-th graph in g.
Examples
--------
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.add_edges([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3])
>>> dgl.laplacian_lambda_max(g)
[1.809016994374948]
"""
if isinstance(g, BatchedDGLGraph):
g_arr = unbatch(g)
else:
g_arr = [g]
rst = []
for g_i in g_arr:
n = g_i.number_of_nodes()
adj = g_i.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
norm = sparse.diags(asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float)
laplacian = sparse.eye(n) - norm * adj * norm
rst.append(sparse.linalg.eigs(laplacian, 1, which='LM',
return_eigenvectors=False)[0].real)
return rst
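For reference, the value in the docstring example can be reproduced with plain NumPy (a standalone sketch, not part of the module): the graph above is an undirected 5-cycle, whose normalized symmetric Laplacian :math:`I - D^{-1/2} A D^{-1/2}` has largest eigenvalue :math:`1 + \cos(\pi/5) \approx 1.809`:

import numpy as np

# Undirected 5-cycle from the docstring example above.
A = np.zeros((5, 5))
for u in range(5):
    A[u, (u + 1) % 5] = A[(u + 1) % 5, u] = 1.0
D_inv_sqrt = np.diag(1.0 / np.sqrt(A.sum(axis=1)))
L = np.eye(5) - D_inv_sqrt @ A @ D_inv_sqrt  # normalized symmetric Laplacian
lambda_max = np.linalg.eigvalsh(L).max()
assert abs(lambda_max - 1.809016994374948) < 1e-8  # matches dgl.laplacian_lambda_max(g)[0]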
_init_api("dgl.transform") _init_api("dgl.transform")
...@@ -110,6 +110,11 @@ def min(x, dim): ...@@ -110,6 +110,11 @@ def min(x, dim):
def prod(x, dim): def prod(x, dim):
"""Computes the prod of array elements over given axes""" """Computes the prod of array elements over given axes"""
pass pass
def matmul(a, b):
"""Compute Matrix Multiplication between a and b"""
pass
############################################################################### ###############################################################################
# Tensor functions used *only* on index tensor # Tensor functions used *only* on index tensor
# ---------------- # ----------------
......
@@ -83,6 +83,9 @@ def min(x, dim):
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return nd.dot(a, b)
record_grad = autograd.record
......
@@ -79,6 +79,9 @@ def min(x, dim):
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return a @ b
class record_grad(object):
def __init__(self):
pass
......
@@ -112,6 +112,56 @@ def test_bidirected_graph():
_test(False, True)
_test(False, False)
def test_khop_graph():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use the k-hop graph to do message passing once.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
def test_khop_adj():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(3):
adj = F.tensor(dgl.khop_adj(g, k))
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use the k-hop adjacency matrix to do message passing once.
h_1 = F.matmul(adj, feat)
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
def test_laplacian_lambda_max():
N = 20
eps = 1e-6
# test DGLGraph
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
l_max = dgl.laplacian_lambda_max(g)
assert (l_max[0] < 2 + eps)
# test BatchedDGLGraph
N_arr = [20, 30, 10, 12]
bg = dgl.batch([
dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for N in N_arr
])
l_max_arr = dgl.laplacian_lambda_max(bg)
assert len(l_max_arr) == len(N_arr)
for l_max in l_max_arr:
assert l_max < 2 + eps
if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
@@ -119,3 +169,6 @@ if __name__ == '__main__':
test_reverse_shared_frames()
test_simple_graph()
test_bidirected_graph()
test_khop_adj()
test_khop_graph()
test_laplacian_lambda_max()
@@ -20,7 +20,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2, norm=False, bias=True)
if F.gpu_ctx():
conv = conv.to(ctx)
print(conv)
# test#1: basic
h0 = F.ones((3, 5))
@@ -37,7 +37,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
@@ -51,7 +51,7 @@ def test_graph_conv():
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
@@ -81,15 +81,15 @@ def _S2AXWb(A, N, X, W, b):
return Y + b
def test_tagconv():
g = dgl.DGLGraph(nx.path_graph(3))
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
norm = th.pow(g.in_degrees().float(), -0.5)
conv = nn.TAGConv(5, 2, bias=True)
if F.gpu_ctx():
conv = conv.to(ctx)
print(conv)
# test#1: basic
@@ -102,27 +102,27 @@ def test_tgconv():
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))
conv = nn.TAGConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#2: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
assert h1.shape[-1] == 2
assert len(g.edata) == 0
# test reset_parameters
old_weight = deepcopy(conv.lin.weight.data)
conv.reset_parameters()
new_weight = conv.lin.weight.data
assert not F.allclose(old_weight, new_weight)
def test_set2set():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
if F.gpu_ctx():
s2s = s2s.to(ctx)
print(s2s)
# test#1: basic
@@ -139,11 +139,12 @@ def test_set2set():
assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
def test_glob_att_pool():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
if F.gpu_ctx():
gap = gap.to(ctx)
print(gap)
# test#1: basic
@@ -158,6 +159,7 @@ def test_glob_att_pool():
assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
def test_simple_pool():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(15))
sum_pool = nn.SumPooling()
@@ -168,6 +170,12 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
if F.gpu_ctx():
sum_pool = sum_pool.to(ctx)
avg_pool = avg_pool.to(ctx)
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h0 = h0.to(ctx)
h1 = sum_pool(h0, g)
assert F.allclose(h1, F.sum(h0, 0))
h1 = avg_pool(h0, g)
@@ -181,6 +189,8 @@ def test_simple_pool():
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = F.randn((bg.number_of_nodes(), 5))
if F.gpu_ctx():
h0 = h0.to(ctx)
h1 = sum_pool(h0, bg)
truth = th.stack([F.sum(h0[:15], 0),
@@ -210,15 +220,16 @@ def test_simple_pool():
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
if F.gpu_ctx():
st_enc_0 = st_enc_0.to(ctx)
st_enc_1 = st_enc_1.to(ctx)
st_dec = st_dec.to(ctx)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
@@ -354,6 +365,207 @@ def test_rgcn():
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
def test_gat_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gat = nn.GATConv(5, 2, 4)
feat = F.randn((100, 5))
if F.gpu_ctx():
gat = gat.to(ctx)
feat = feat.to(ctx)
h = gat(feat, g)
assert h.shape[-1] == 2 and h.shape[-2] == 4
def test_sage_conv():
for aggre_type in ['mean', 'pool', 'gcn', 'lstm']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
feat = feat.to(ctx)
h = sage(feat, g)
assert h.shape[-1] == 10
def test_sgc_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
# not cached
sgc = nn.SGConv(5, 10, 3)
feat = F.randn((100, 5))
if F.gpu_ctx():
sgc = sgc.to(ctx)
feat = feat.to(ctx)
h = sgc(feat, g)
assert h.shape[-1] == 10
# cached
sgc = nn.SGConv(5, 10, 3, True)
if F.gpu_ctx():
sgc = sgc.to(ctx)
h_0 = sgc(feat, g)
h_1 = sgc(feat + 1, g)
assert F.allclose(h_0, h_1)
assert h_0.shape[-1] == 10
def test_appnp_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
appnp = nn.APPNPConv(10, 0.1)
feat = F.randn((100, 5))
if F.gpu_ctx():
appnp = appnp.to(ctx)
feat = feat.to(ctx)
h = appnp(feat, g)
assert h.shape[-1] == 5
def test_gin_conv():
for aggregator_type in ['mean', 'max', 'sum']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gin = nn.GINConv(
th.nn.Linear(5, 12),
aggregator_type
)
feat = F.randn((100, 5))
if F.gpu_ctx():
gin = gin.to(ctx)
feat = feat.to(ctx)
h = gin(feat, g)
assert h.shape[-1] == 12
def test_agnn_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
agnn = nn.AGNNConv(1)
feat = F.randn((100, 5))
if F.gpu_ctx():
agnn = agnn.to(ctx)
feat = feat.to(ctx)
h = agnn(feat, g)
assert h.shape[-1] == 5
def test_gated_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
ggconv = nn.GatedGraphConv(5, 10, 5, 3)
etypes = th.arange(g.number_of_edges()) % 3
feat = F.randn((100, 5))
if F.gpu_ctx():
ggconv = ggconv.to(ctx)
feat = feat.to(ctx)
etypes = etypes.to(ctx)
h = ggconv(feat, etypes, g)
# currently we only do shape check
assert h.shape[-1] == 10
def test_nn_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
edge_func = th.nn.Linear(4, 5 * 10)
nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((100, 5))
efeat = F.randn((g.number_of_edges(), 4))
if F.gpu_ctx():
nnconv = nnconv.to(ctx)
feat = feat.to(ctx)
efeat = efeat.to(ctx)
h = nnconv(feat, efeat, g)
# currently we only do shape check
assert h.shape[-1] == 10
def test_gmm_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
feat = F.randn((100, 5))
pseudo = F.randn((g.number_of_edges(), 3))
if F.gpu_ctx():
gmmconv = gmmconv.to(ctx)
feat = feat.to(ctx)
pseudo = pseudo.to(ctx)
h = gmmconv(feat, pseudo, g)
# currently we only do shape check
assert h.shape[-1] == 10
def test_dense_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
conv = nn.GraphConv(5, 2, norm=False, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True)
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
conv = conv.to(ctx)
dense_conv = dense_conv.to(ctx)
feat = feat.to(ctx)
out_conv = conv(feat, g)
out_dense_conv = dense_conv(feat, adj)
assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
sage = nn.SAGEConv(5, 2, 'gcn',)
dense_sage = nn.DenseSAGEConv(5, 2)
dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
dense_sage = dense_sage.to(ctx)
feat = feat.to(ctx)
out_sage = sage(feat, g)
out_dense_sage = dense_sage(feat, adj)
assert F.allclose(out_sage, out_dense_sage)
def test_dense_cheb_conv():
for k in range(1, 4):
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
cheb = nn.ChebConv(5, 2, k)
dense_cheb = nn.DenseChebConv(5, 2, k)
for i in range(len(cheb.fc)):
dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
if cheb.bias is not None:
dense_cheb.bias.data = cheb.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
cheb = cheb.to(ctx)
dense_cheb = dense_cheb.to(ctx)
feat = feat.to(ctx)
out_cheb = cheb(feat, g)
out_dense_cheb = dense_cheb(feat, adj)
assert F.allclose(out_cheb, out_dense_cheb)
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
@@ -362,3 +574,17 @@ if __name__ == '__main__':
test_simple_pool()
test_set_trans()
test_rgcn()
test_tagconv()
test_gat_conv()
test_sage_conv()
test_sgc_conv()
test_appnp_conv()
test_gin_conv()
test_agnn_conv()
test_gated_graph_conv()
test_nn_conv()
test_gmm_conv()
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()