"vscode:/vscode.git/clone" did not exist on "ac61eefc9f2fbd4d2190d5673a4fcd77da9a93ab"
Unverified Commit ca302a13 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Feature] Add dgl.nn.*.Sequential for usability (#1166)



* upd

* upd

* upd

* upd

* lint

* upd

* upd

* upd

* upd
Co-authored-by: default avatarVoVAllen <VoVAllen@users.noreply.github.com>
parent 3d664628
......@@ -190,6 +190,13 @@ Set2Set
Utility Modules
----------------------------------------
Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.mxnet.utils.Sequential
:members:
:show-inheritance:
Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -210,6 +210,13 @@ SetTransformerDecoder
Utility Modules
----------------------------------------
Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.utils.Sequential
:members:
:show-inheritance:
KNNGraph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -2,3 +2,4 @@
from .conv import *
from .glob import *
from .softmax import *
from .utils import Sequential
"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name
from mxnet import nd
from mxnet import nd, gluon
import numpy as np
from ... import DGLGraph
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
......@@ -105,3 +106,124 @@ def normalize(x, p=2, axis=1, eps=1e-12):
"""
denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf'))
return x / denom
class Sequential(gluon.nn.Sequential):
"""A squential container for stacking graph neural network blocks.
We support two modes: sequentially apply GNN blocks on the same graph or
a list of given graphs. In the second case, the number of graphs equals the
number of blocks inside this container.
Examples
--------
Mode 1: sequentially apply GNN modules on the same graph
>>> import dgl
>>> from mxnet import nd
>>> from mxnet.gluon import nn
>>> import dgl.function as fn
>>> from dgl.nn.mxnet import Sequential
>>> class ExampleLayer(nn.Block):
>>> def __init__(self, **kwargs):
>>> super().__init__(**kwargs)
>>> def forward(self, graph, n_feat, e_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
>>> e_feat += graph.edata['e']
>>> return n_feat, e_feat
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> net = Sequential()
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.initialize()
>>> n_feat = nd.random.randn(3, 4)
>>> e_feat = nd.random.randn(9, 4)
>>> net(g, n_feat, e_feat)
(
[[ 12.412863 99.61184 21.472883 -57.625923 ]
[ 10.08097 100.68611 20.627377 -60.13458 ]
[ 11.7912245 101.80654 22.427956 -58.32772 ]]
<NDArray 3x4 @cpu(0)>,
[[ 21.818504 198.12076 42.72387 -115.147736]
[ 23.070837 195.49811 43.42292 -116.17203 ]
[ 24.330334 197.10927 42.40048 -118.06538 ]
[ 21.907919 199.11469 42.1187 -115.35658 ]
[ 22.849625 198.79213 43.866085 -113.65381 ]
[ 20.926125 198.116 42.64334 -114.246704]
[ 23.003159 197.06662 41.796425 -117.14977 ]
[ 21.391375 198.3348 41.428078 -116.30361 ]
[ 21.291483 200.0701 40.8239 -118.07314 ]]
<NDArray 9x4 @cpu(0)>)
Mode 2: sequentially apply GNN modules on different graphs
>>> import dgl
>>> from mxnet import nd
>>> from mxnet.gluon import nn
>>> import dgl.function as fn
>>> import networkx as nx
>>> from dgl.nn.mxnet import Sequential
>>> class ExampleLayer(nn.Block):
>>> def __init__(self, **kwargs):
>>> super().__init__(**kwargs)
>>> def forward(self, graph, n_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
>>>
>>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
>>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
>>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
>>> net = Sequential()
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.add(ExampleLayer())
>>> net.initialize()
>>> n_feat = nd.random.randn(32, 4)
>>> net([g1, g2, g3], n_feat)
[[-101.289566 -22.584694 -89.25348 -151.6447 ]
[-130.74239 -49.494812 -120.250854 -199.81546 ]
[-112.32089 -50.036713 -116.13266 -190.38638 ]
[-119.23065 -26.78553 -111.11185 -166.08322 ]]
<NDArray 4x4 @cpu(0)>
"""
def __init__(self, prefix=None, params=None):
super(Sequential, self).__init__(prefix=prefix, params=params)
def forward(self, graph, *feats):
"""Sequentially apply modules to the input.
Parameters
----------
graph: a DGLGraph or a list of DGLGraphs.
*feats: input features.
The output of i-th block should match that of the input
of (i+1)-th block.
"""
if isinstance(graph, list):
for graph_i, module in zip(graph, self):
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph_i, *feats)
elif isinstance(graph, DGLGraph):
for module in self:
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph, *feats)
else:
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraph s')
return feats
......@@ -3,3 +3,4 @@ from .conv import *
from .glob import *
from .softmax import *
from .factory import *
from .utils import Sequential
......@@ -3,6 +3,7 @@
import torch as th
from torch import nn
from ... import DGLGraph
def matmul_maybe_select(A, B):
......@@ -101,3 +102,116 @@ class Identity(nn.Module):
def forward(self, x):
"""Return input"""
return x
class Sequential(nn.Sequential):
"""A squential container for stacking graph neural network modules.
We support two modes: sequentially apply GNN modules on the same graph or
a list of given graphs. In the second case, the number of graphs equals the
number of modules inside this container.
Parameters
----------
*args : sub-modules of type torch.nn.Module, will be added to the container in
the order they are passed in the constructor.
Examples
--------
Mode 1: sequentially apply GNN modules on the same graph
>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> from dgl.nn.pytorch import Sequential
>>> class ExampleLayer(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> def forward(self, graph, n_feat, e_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
>>> e_feat += graph.edata['e']
>>> return n_feat, e_feat
>>>
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(3, 4)
>>> e_feat = torch.rand(9, 4)
>>> net(g, n_feat, e_feat)
(tensor([[39.8597, 45.4542, 25.1877, 30.8086],
[40.7095, 45.3985, 25.4590, 30.0134],
[40.7894, 45.2556, 25.5221, 30.4220]]), tensor([[80.3772, 89.7752, 50.7762, 60.5520],
[80.5671, 89.3736, 50.6558, 60.6418],
[80.4620, 89.5142, 50.3643, 60.3126],
[80.4817, 89.8549, 50.9430, 59.9108],
[80.2284, 89.6954, 50.0448, 60.1139],
[79.7846, 89.6882, 50.5097, 60.6213],
[80.2654, 90.2330, 50.2787, 60.6937],
[80.3468, 90.0341, 50.2062, 60.2659],
[80.0556, 90.2789, 50.2882, 60.5845]]))
Mode 2: sequentially apply GNN modules on different graphs
>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> import networkx as nx
>>> from dgl.nn.pytorch import Sequential
>>> class ExampleLayer(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> def forward(self, graph, n_feat):
>>> graph = graph.local_var()
>>> graph.ndata['h'] = n_feat
>>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> n_feat += graph.ndata['h']
>>> return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
>>>
>>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
>>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
>>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(32, 4)
>>> net([g1, g2, g3], n_feat)
tensor([[209.6221, 225.5312, 193.8920, 220.1002],
[250.0169, 271.9156, 240.2467, 267.7766],
[220.4007, 239.7365, 213.8648, 234.9637],
[196.4630, 207.6319, 184.2927, 208.7465]])
"""
def __init__(self, *args):
super(Sequential, self).__init__(*args)
def forward(self, graph, *feats):
"""Sequentially apply modules to the input.
Parameters
----------
graph: a DGLGraph or a list of DGLGraphs.
*feats: input features.
The output of i-th block should match that of the input
of (i+1)-th block.
"""
if isinstance(graph, list):
for graph_i, module in zip(graph, self):
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph_i, *feats)
elif isinstance(graph, DGLGraph):
for module in self:
if not isinstance(feats, tuple):
feats = (feats,)
feats = module(graph, *feats)
else:
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraph s')
return feats
......@@ -4,6 +4,7 @@ import numpy as np
import scipy as sp
import dgl
import dgl.nn.mxnet as nn
import dgl.function as fn
import backend as F
from mxnet import autograd, gluon, nd
......@@ -508,6 +509,60 @@ def test_rgcn():
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
def test_sequential():
ctx = F.ctx()
# test single graph
class ExampleLayer(gluon.nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, graph, n_feat, e_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
e_feat += graph.edata['e']
return n_feat, e_feat
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
net = nn.Sequential()
net.add(ExampleLayer())
net.add(ExampleLayer())
net.add(ExampleLayer())
net.initialize(ctx=ctx)
n_feat = F.randn((3, 4))
e_feat = F.randn((9, 4))
n_feat, e_feat = net(g, n_feat, e_feat)
assert n_feat.shape == (3, 4)
assert e_feat.shape == (9, 4)
# test multiple graphs
class ExampleLayer(gluon.nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, graph, n_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
net = nn.Sequential()
net.add(ExampleLayer())
net.add(ExampleLayer())
net.add(ExampleLayer())
net.initialize(ctx=ctx)
n_feat = F.randn((32, 4))
n_feat = net([g1, g2, g3], n_feat)
assert n_feat.shape == (4, 4)
if __name__ == '__main__':
test_graph_conv()
test_gat_conv()
......@@ -530,3 +585,4 @@ if __name__ == '__main__':
test_glob_att_pool()
test_simple_pool()
test_rgcn()
test_sequential()
......@@ -2,6 +2,7 @@ import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
from copy import deepcopy
......@@ -19,7 +20,6 @@ def test_graph_conv():
adj = g.adjacency_matrix(ctx=ctx)
conv = nn.GraphConv(5, 2, norm=False, bias=True)
if F.gpu_ctx():
conv = conv.to(ctx)
print(conv)
# test#1: basic
......@@ -36,7 +36,6 @@ def test_graph_conv():
assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
......@@ -50,7 +49,6 @@ def test_graph_conv():
assert len(g.edata) == 0
conv = nn.GraphConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#3: basic
h0 = F.ones((3, 5))
......@@ -88,7 +86,6 @@ def test_tagconv():
norm = th.pow(g.in_degrees().float(), -0.5)
conv = nn.TAGConv(5, 2, bias=True)
if F.gpu_ctx():
conv = conv.to(ctx)
print(conv)
......@@ -103,7 +100,6 @@ def test_tagconv():
assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))
conv = nn.TAGConv(5, 2)
if F.gpu_ctx():
conv = conv.to(ctx)
# test#2: basic
......@@ -122,7 +118,6 @@ def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))
s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
if F.gpu_ctx():
s2s = s2s.to(ctx)
print(s2s)
......@@ -144,7 +139,6 @@ def test_glob_att_pool():
g = dgl.DGLGraph(nx.path_graph(10))
gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
if F.gpu_ctx():
gap = gap.to(ctx)
print(gap)
......@@ -171,12 +165,10 @@ def test_simple_pool():
# test#1: basic
h0 = F.randn((g.number_of_nodes(), 5))
if F.gpu_ctx():
sum_pool = sum_pool.to(ctx)
avg_pool = avg_pool.to(ctx)
max_pool = max_pool.to(ctx)
sort_pool = sort_pool.to(ctx)
h0 = h0.to(ctx)
h1 = sum_pool(g, h0)
assert F.allclose(h1, F.sum(h0, 0))
h1 = avg_pool(g, h0)
......@@ -190,9 +182,6 @@ def test_simple_pool():
g_ = dgl.DGLGraph(nx.path_graph(5))
bg = dgl.batch([g, g_, g, g_, g])
h0 = F.randn((bg.number_of_nodes(), 5))
if F.gpu_ctx():
h0 = h0.to(ctx)
h1 = sum_pool(bg, h0)
truth = th.stack([F.sum(h0[:15], 0),
F.sum(h0[15:20], 0),
......@@ -227,7 +216,6 @@ def test_set_trans():
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
if F.gpu_ctx():
st_enc_0 = st_enc_0.to(ctx)
st_enc_1 = st_enc_1.to(ctx)
st_dec = st_dec.to(ctx)
......@@ -400,10 +388,7 @@ def test_gat_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gat = nn.GATConv(5, 2, 4)
feat = F.randn((100, 5))
if F.gpu_ctx():
gat = gat.to(ctx)
h = gat(g, feat)
assert h.shape[-1] == 2 and h.shape[-2] == 4
......@@ -413,10 +398,7 @@ def test_sage_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 10
......@@ -426,8 +408,6 @@ def test_sgc_conv():
# not cached
sgc = nn.SGConv(5, 10, 3)
feat = F.randn((100, 5))
if F.gpu_ctx():
sgc = sgc.to(ctx)
h = sgc(g, feat)
......@@ -435,10 +415,7 @@ def test_sgc_conv():
# cached
sgc = nn.SGConv(5, 10, 3, True)
if F.gpu_ctx():
sgc = sgc.to(ctx)
h_0 = sgc(g, feat)
h_1 = sgc(g, feat + 1)
assert F.allclose(h_0, h_1)
......@@ -449,8 +426,6 @@ def test_appnp_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
appnp = nn.APPNPConv(10, 0.1)
feat = F.randn((100, 5))
if F.gpu_ctx():
appnp = appnp.to(ctx)
h = appnp(g, feat)
......@@ -465,10 +440,7 @@ def test_gin_conv():
aggregator_type
)
feat = F.randn((100, 5))
if F.gpu_ctx():
gin = gin.to(ctx)
h = gin(g, feat)
assert h.shape[-1] == 12
......@@ -477,10 +449,7 @@ def test_agnn_conv():
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
agnn = nn.AGNNConv(1)
feat = F.randn((100, 5))
if F.gpu_ctx():
agnn = agnn.to(ctx)
h = agnn(g, feat)
assert h.shape[-1] == 5
......@@ -490,8 +459,6 @@ def test_gated_graph_conv():
ggconv = nn.GatedGraphConv(5, 10, 5, 3)
etypes = th.arange(g.number_of_edges()) % 3
feat = F.randn((100, 5))
if F.gpu_ctx():
ggconv = ggconv.to(ctx)
etypes = etypes.to(ctx)
......@@ -506,10 +473,7 @@ def test_nn_conv():
nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((100, 5))
efeat = F.randn((g.number_of_edges(), 4))
if F.gpu_ctx():
nnconv = nnconv.to(ctx)
h = nnconv(g, feat, efeat)
# currently we only do shape check
assert h.shape[-1] == 10
......@@ -520,10 +484,7 @@ def test_gmm_conv():
gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
feat = F.randn((100, 5))
pseudo = F.randn((g.number_of_edges(), 3))
if F.gpu_ctx():
gmmconv = gmmconv.to(ctx)
h = gmmconv(g, feat, pseudo)
# currently we only do shape check
assert h.shape[-1] == 10
......@@ -537,10 +498,8 @@ def test_dense_graph_conv():
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
conv = conv.to(ctx)
dense_conv = dense_conv.to(ctx)
out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv)
......@@ -554,10 +513,8 @@ def test_dense_sage_conv():
dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
sage = sage.to(ctx)
dense_sage = dense_sage.to(ctx)
out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage)
......@@ -574,14 +531,60 @@ def test_dense_cheb_conv():
if cheb.bias is not None:
dense_cheb.bias.data = cheb.bias.data
feat = F.randn((100, 5))
if F.gpu_ctx():
cheb = cheb.to(ctx)
dense_cheb = dense_cheb.to(ctx)
out_cheb = cheb(g, feat, [2.0])
out_dense_cheb = dense_cheb(adj, feat, 2.0)
assert F.allclose(out_cheb, out_dense_cheb)
def test_sequential():
ctx = F.ctx()
# Test single graph
class ExampleLayer(th.nn.Module):
def __init__(self):
super().__init__()
def forward(self, graph, n_feat, e_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
e_feat += graph.edata['e']
return n_feat, e_feat
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
n_feat = F.randn((3, 4))
e_feat = F.randn((9, 4))
net = net.to(ctx)
n_feat, e_feat = net(g, n_feat, e_feat)
assert n_feat.shape == (3, 4)
assert e_feat.shape == (9, 4)
# Test multiple graph
class ExampleLayer(th.nn.Module):
def __init__(self):
super().__init__()
def forward(self, graph, n_feat):
graph = graph.local_var()
graph.ndata['h'] = n_feat
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
n_feat += graph.ndata['h']
return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
net = net.to(ctx)
n_feat = F.randn((32, 4))
n_feat = net([g1, g2, g3], n_feat)
assert n_feat.shape == (4, 4)
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
......@@ -604,4 +607,5 @@ if __name__ == '__main__':
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()
test_sequential()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment