Unverified Commit ca302a13 authored by Zihao Ye, committed by GitHub

[Feature] Add dgl.nn.*.Sequential for usability (#1166)



* upd

* upd

* upd

* upd

* lint

* upd

* upd

* upd

* upd
Co-authored-by: VoVAllen <VoVAllen@users.noreply.github.com>
parent 3d664628
@@ -190,6 +190,13 @@ Set2Set
Utility Modules
----------------------------------------
Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.mxnet.utils.Sequential
:members:
:show-inheritance:
Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -210,6 +210,13 @@ SetTransformerDecoder
Utility Modules
----------------------------------------
Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.utils.Sequential
:members:
:show-inheritance:
KNNGraph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -2,3 +2,4 @@
from .conv import *
from .glob import *
from .softmax import *
from .utils import Sequential
"""Utilities for pytorch NN package""" """Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name #pylint: disable=no-member, invalid-name
from mxnet import nd from mxnet import nd, gluon
import numpy as np import numpy as np
from ... import DGLGraph
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
@@ -105,3 +106,124 @@ def normalize(x, p=2, axis=1, eps=1e-12):
"""
denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf'))
return x / denom
class Sequential(gluon.nn.Sequential):
"""A squential container for stacking graph neural network blocks.
We support two modes: sequentially apply GNN blocks on the same graph or
a list of given graphs. In the second case, the number of graphs equals the
number of blocks inside this container.
Examples
--------
Mode 1: sequentially apply GNN modules on the same graph
>>> import dgl
>>> from mxnet import nd
>>> from mxnet.gluon import nn
>>> import dgl.function as fn
>>> from dgl.nn.mxnet import Sequential
>>> class ExampleLayer(nn.Block):
>>> def __init__(self, **kwargs):
>>> super().__init__(**kwargs)
>>> def forward(self, graph, n_feat, e_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
>>> e_feat += graph.edata['e']
>>> return n_feat, e_feat
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> net = Sequential()
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.initialize()
>>> n_feat = nd.random.randn(3, 4)
>>> e_feat = nd.random.randn(9, 4)
>>> net(g, n_feat, e_feat)
(
[[ 12.412863 99.61184 21.472883 -57.625923 ]
[ 10.08097 100.68611 20.627377 -60.13458 ]
[ 11.7912245 101.80654 22.427956 -58.32772 ]]
<NDArray 3x4 @cpu(0)>,
[[ 21.818504 198.12076 42.72387 -115.147736]
[ 23.070837 195.49811 43.42292 -116.17203 ]
[ 24.330334 197.10927 42.40048 -118.06538 ]
[ 21.907919 199.11469 42.1187 -115.35658 ]
[ 22.849625 198.79213 43.866085 -113.65381 ]
[ 20.926125 198.116 42.64334 -114.246704]
[ 23.003159 197.06662 41.796425 -117.14977 ]
[ 21.391375 198.3348 41.428078 -116.30361 ]
[ 21.291483 200.0701 40.8239 -118.07314 ]]
<NDArray 9x4 @cpu(0)>)
Mode 2: sequentially apply GNN modules on different graphs
>>> import dgl
>>> from mxnet import nd
>>> from mxnet.gluon import nn
>>> import dgl.function as fn
>>> import networkx as nx
>>> from dgl.nn.mxnet import Sequential
>>> class ExampleLayer(nn.Block):
>>> def __init__(self, **kwargs):
>>> super().__init__(**kwargs)
>>> def forward(self, graph, n_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
>>>
>>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
>>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
>>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
>>> net = Sequential()
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.initialize()
>>> n_feat = nd.random.randn(32, 4)
>>> net([g1, g2, g3], n_feat)
[[-101.289566 -22.584694 -89.25348 -151.6447 ]
[-130.74239 -49.494812 -120.250854 -199.81546 ]
[-112.32089 -50.036713 -116.13266 -190.38638 ]
[-119.23065 -26.78553 -111.11185 -166.08322 ]]
<NDArray 4x4 @cpu(0)>
"""
def __init__(self, prefix=None, params=None):
super(Sequential, self).__init__(prefix=prefix, params=params)
def forward(self, graph, *feats):
"""Sequentially apply modules to the input.
Parameters
----------
graph : DGLGraph or list of DGLGraphs
    The graph (or list of graphs) to apply the blocks on.
*feats : input features
    The output of the i-th block should match the input of the (i+1)-th block.
"""
if isinstance(graph, list):
for graph_i, module in zip(graph, self):
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph_i, *feats)
elif isinstance(graph, DGLGraph):
for module in self:
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph, *feats)
else:
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraphs')
return feats
@@ -3,3 +3,4 @@ from .conv import *
from .glob import *
from .softmax import *
from .factory import *
from .utils import Sequential
@@ -3,6 +3,7 @@
import torch as th
from torch import nn
from ... import DGLGraph
def matmul_maybe_select(A, B):
@@ -101,3 +102,116 @@ class Identity(nn.Module):
def forward(self, x):
"""Return input"""
return x
class Sequential(nn.Sequential):
"""A squential container for stacking graph neural network modules.
We support two modes: sequentially apply GNN modules on the same graph or
a list of given graphs. In the second case, the number of graphs equals the
number of modules inside this container.
Parameters
----------
*args : sub-modules of type torch.nn.Module; they are added to the container
    in the order they are passed to the constructor.
Examples
--------
Mode 1: sequentially apply GNN modules on the same graph
>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> from dgl.nn.pytorch import Sequential
>>> class ExampleLayer(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> def forward(self, graph, n_feat, e_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
>>> e_feat += graph.edata['e']
>>> return n_feat, e_feat
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(3, 4)
>>> e_feat = torch.rand(9, 4)
>>> net(g, n_feat, e_feat)
(tensor([[39.8597, 45.4542, 25.1877, 30.8086],
[40.7095, 45.3985, 25.4590, 30.0134],
[40.7894, 45.2556, 25.5221, 30.4220]]), tensor([[80.3772, 89.7752, 50.7762, 60.5520],
[80.5671, 89.3736, 50.6558, 60.6418],
[80.4620, 89.5142, 50.3643, 60.3126],
[80.4817, 89.8549, 50.9430, 59.9108],
[80.2284, 89.6954, 50.0448, 60.1139],
[79.7846, 89.6882, 50.5097, 60.6213],
[80.2654, 90.2330, 50.2787, 60.6937],
[80.3468, 90.0341, 50.2062, 60.2659],
[80.0556, 90.2789, 50.2882, 60.5845]]))
Mode 2: sequentially apply GNN modules on different graphs
>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> import networkx as nx
>>> from dgl.nn.pytorch import Sequential
>>> class ExampleLayer(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> def forward(self, graph, n_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
>>>
>>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
>>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
>>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(32, 4)
>>> net([g1, g2, g3], n_feat)
tensor([[209.6221, 225.5312, 193.8920, 220.1002],
[250.0169, 271.9156, 240.2467, 267.7766],
[220.4007, 239.7365, 213.8648, 234.9637],
[196.4630, 207.6319, 184.2927, 208.7465]])
"""
def __init__(self, *args):
super(Sequential, self).__init__(*args)
def forward(self, graph, *feats):
"""Sequentially apply modules to the input.
Parameters
----------
graph : DGLGraph or list of DGLGraphs
    The graph (or list of graphs) to apply the modules on.
*feats : input features
    The output of the i-th module should match the input of the (i+1)-th module.
"""
if isinstance(graph, list):
for graph_i, module in zip(graph, self):
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph_i, *feats)
elif isinstance(graph, DGLGraph):
for module in self:
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph, *feats)
else:
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraphs')
return feats
@@ -4,6 +4,7 @@ import numpy as np
import scipy as sp
import dgl
import dgl.nn.mxnet as nn
import dgl.function as fn
import backend as F
from mxnet import autograd, gluon, nd
@@ -508,6 +509,60 @@ def test_rgcn():
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
def test_sequential():
ctx = F.ctx()
# test single graph
class ExampleLayer(gluon.nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, graph, n_feat, e_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
e_feat += graph.edata['e']
return n_feat, e_feat
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
net = nn.Sequential()
net.add(ExampleLayer())
net.add(ExampleLayer())
net.add(ExampleLayer())
net.initialize(ctx=ctx)
n_feat = F.randn((3, 4))
e_feat = F.randn((9, 4))
n_feat, e_feat = net(g, n_feat, e_feat)
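# ExampleLayer preserves shapes, so node feats stay (3, 4) and edge feats stay (9, 4)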
assert n_feat.shape == (3, 4)
assert e_feat.shape == (9, 4)
# test multiple graphs
class ExampleLayer(gluon.nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, graph, n_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
net = nn.Sequential()
net.add(ExampleLayer())
net.add(ExampleLayer())
net.add(ExampleLayer())
net.initialize(ctx=ctx)
n_feat = F.randn((32, 4))
n_feat = net([g1, g2, g3], n_feat)
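# each ExampleLayer halves the number of feature rows: 32 -> 16 -> 8 -> 4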
assert n_feat.shape == (4, 4)
if __name__ == '__main__':
test_graph_conv()
test_gat_conv()
@@ -530,3 +585,4 @@ if __name__ == '__main__':
test_glob_att_pool()
test_simple_pool()
test_rgcn()
test_sequential()
@@ -2,6 +2,7 @@ import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
from copy import deepcopy
@@ -19,8 +20,7 @@ def test_graph_conv():
adj = g.adjacency_matrix(ctx=ctx)
conv = nn.GraphConv(5, 2, norm=False, bias=True)
conv = conv.to(ctx)
print(conv)
# test#1: basic
h0 = F.ones((3, 5))
@@ -36,8 +36,7 @@ def test_graph_conv():
assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(g, h0)
@@ -50,8 +49,7 @@ def test_graph_conv():
assert len(g.edata) == 0
conv = nn.GraphConv(5, 2)
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
h1 = conv(g, h0)
@@ -88,8 +86,7 @@ def test_tagconv():
norm = th.pow(g.in_degrees().float(), -0.5)
conv = nn.TAGConv(5, 2, bias=True)
conv = conv.to(ctx)
print(conv)
# test#1: basic
@@ -103,8 +100,7 @@ def test_tagconv():
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))
conv = nn.TAGConv(5, 2)
conv = conv.to(ctx)
# test#2: basic
h0 = F.ones((3, 5))
@@ -122,8 +118,7 @@ def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
s2s = s2s.to(ctx)
print(s2s)
# test#1: basic
@@ -144,8 +139,7 @@ def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
gap = gap.to(ctx)
print(gap)
# test#1: basic
@@ -171,12 +165,10 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
sum_pool = sum_pool.to(ctx)
avg_pool = avg_pool.to(ctx)
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
h1 = avg_pool(g, h0)
@@ -190,9 +182,6 @@ def test_simple_pool():
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = F.randn((bg.number_of_nodes(), 5))
h1 = sum_pool(bg, h0)
truth = th.stack([F.sum(h0[:15], 0),
F.sum(h0[15:20], 0),
@@ -227,10 +216,9 @@ def test_set_trans():
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
st_enc_0 = st_enc_0.to(ctx)
st_enc_1 = st_enc_1.to(ctx)
st_dec = st_dec.to(ctx)
print(st_enc_0, st_enc_1, st_dec)
# test#1: basic
@@ -400,10 +388,7 @@ def test_gat_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gat = nn.GATConv(5, 2, 4)
feat = F.randn((100, 5))
gat = gat.to(ctx)
h = gat(g, feat)
assert h.shape[-1] == 2 and h.shape[-2] == 4
@@ -413,10 +398,7 @@ def test_sage_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 10
@@ -426,19 +408,14 @@ def test_sgc_conv():
# not cached
sgc = nn.SGConv(5, 10, 3)
feat = F.randn((100, 5))
sgc = sgc.to(ctx)
h = sgc(g, feat)
assert h.shape[-1] == 10
# cached
sgc = nn.SGConv(5, 10, 3, True)
sgc = sgc.to(ctx)
h_0 = sgc(g, feat)
h_1 = sgc(g, feat + 1)
assert F.allclose(h_0, h_1)
@@ -449,9 +426,7 @@ def test_appnp_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
appnp = nn.APPNPConv(10, 0.1)
feat = F.randn((100, 5))
appnp = appnp.to(ctx)
h = appnp(g, feat)
assert h.shape[-1] == 5
@@ -465,10 +440,7 @@ def test_gin_conv():
aggregator_type
)
feat = F.randn((100, 5))
gin = gin.to(ctx)
h = gin(g, feat)
assert h.shape[-1] == 12
@@ -477,10 +449,7 @@ def test_agnn_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
agnn = nn.AGNNConv(1)
feat = F.randn((100, 5))
agnn = agnn.to(ctx)
h = agnn(g, feat)
assert h.shape[-1] == 5
@@ -490,10 +459,8 @@ def test_gated_graph_conv():
ggconv = nn.GatedGraphConv(5, 10, 5, 3)
etypes = th.arange(g.number_of_edges()) % 3
feat = F.randn((100, 5))
ggconv = ggconv.to(ctx)
etypes = etypes.to(ctx)
h = ggconv(g, feat, etypes)
# current we only do shape check
@@ -506,10 +473,7 @@ def test_nn_conv():
nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((100, 5))
efeat = F.randn((g.number_of_edges(), 4))
nnconv = nnconv.to(ctx)
h = nnconv(g, feat, efeat)
# currently we only do shape check
assert h.shape[-1] == 10
@@ -520,10 +484,7 @@ def test_gmm_conv():
gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
feat = F.randn((100, 5))
pseudo = F.randn((g.number_of_edges(), 3))
gmmconv = gmmconv.to(ctx)
h = gmmconv(g, feat, pseudo)
# currently we only do shape check
assert h.shape[-1] == 10
@@ -537,10 +498,8 @@ def test_dense_graph_conv():
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
feat = F.randn((100, 5))
conv = conv.to(ctx)
dense_conv = dense_conv.to(ctx)
out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv)
@@ -554,10 +513,8 @@ def test_dense_sage_conv():
dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data
feat = F.randn((100, 5))
sage = sage.to(ctx)
dense_sage = dense_sage.to(ctx)
out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage)
@@ -574,14 +531,60 @@ def test_dense_cheb_conv():
if cheb.bias is not None:
dense_cheb.bias.data = cheb.bias.data
feat = F.randn((100, 5))
cheb = cheb.to(ctx)
dense_cheb = dense_cheb.to(ctx)
out_cheb = cheb(g, feat, [2.0])
out_dense_cheb = dense_cheb(adj, feat, 2.0)
assert F.allclose(out_cheb, out_dense_cheb)
def test_sequential():
ctx = F.ctx()
# Test single graph
class ExampleLayer(th.nn.Module):
def __init__(self):
super().__init__()
def forward(self, graph, n_feat, e_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
e_feat += graph.edata['e']
return n_feat, e_feat
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
n_feat = F.randn((3, 4))
e_feat = F.randn((9, 4))
net = net.to(ctx)
n_feat, e_feat = net(g, n_feat, e_feat)
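# ExampleLayer preserves shapes, so node feats stay (3, 4) and edge feats stay (9, 4)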
assert n_feat.shape == (3, 4)
assert e_feat.shape == (9, 4)
# Test multiple graphs
class ExampleLayer(th.nn.Module):
def __init__(self):
super().__init__()
def forward(self, graph, n_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
net = net.to(ctx)
n_feat = F.randn((32, 4))
n_feat = net([g1, g2, g3], n_feat)
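# each ExampleLayer halves the number of feature rows: 32 -> 16 -> 8 -> 4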
assert n_feat.shape == (4, 4)
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
@@ -604,4 +607,5 @@ if __name__ == '__main__':
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()
test_sequential()