OpenDAS / dgl / Commits / 9314aabd

Unverified commit 9314aabd, authored Aug 27, 2019 by Zihao Ye, committed by GitHub on Aug 27, 2019
[Refactor] Interface of nn modules (#798)

* refactor
* upd mpnn
parent 650f6ee1

Changes: 18
Showing 18 changed files with 194 additions and 184 deletions (+194, -184)
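The whole commit is one mechanical change: every nn module's `forward` (and every call site in the examples, model zoo, and tests) now takes the graph as its first argument instead of its last. A minimal before/after sketch of the new convention; the graph and feature sizes below are illustrative and mirror the unit tests, and the constructor is assumed to be the usual `GraphConv(in_feats, out_feats)` pair:

import networkx as nx
import torch as th
import dgl
from dgl.nn.pytorch.conv import GraphConv

g = dgl.DGLGraph(nx.path_graph(3))   # small illustrative graph
feat = th.ones(3, 5)                 # one 5-dim feature per node

conv = GraphConv(5, 2)
# before this commit: h = conv(feat, g)
h = conv(g, feat)                    # after: the graph always comes first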
examples/mxnet/gcn/gcn.py                                +1  -1
examples/pytorch/appnp/appnp.py                          +1  -1
examples/pytorch/gat/gat.py                              +2  -2
examples/pytorch/gcn/gcn.py                              +1  -1
examples/pytorch/gin/gin.py                              +1  -1
examples/pytorch/graphsage/graphsage.py                  +1  -1
examples/pytorch/model_zoo/citation_network/conf.py      +1  -1
examples/pytorch/model_zoo/citation_network/models.py    +10 -10
examples/pytorch/sgc/sgc.py                              +2  -2
examples/pytorch/sgc/sgc_reddit.py                       +2  -2
examples/pytorch/tagcn/tagcn.py                          +1  -1
python/dgl/model_zoo/chem/mpnn.py                        +1  -1
python/dgl/nn/mxnet/conv.py                              +3  -3
python/dgl/nn/mxnet/glob.py                              +18 -18
python/dgl/nn/pytorch/conv.py                            +63 -53
python/dgl/nn/pytorch/glob.py                            +24 -24
tests/mxnet/test_nn.py                                   +19 -19
tests/pytorch/test_nn.py                                 +43 -43
examples/mxnet/gcn/gcn.py  (view file @ 9314aabd)

@@ -36,5 +36,5 @@ class GCN(gluon.Block):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
examples/pytorch/appnp/appnp.py  (view file @ 9314aabd)

@@ -53,5 +53,5 @@ class APPNP(nn.Module):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h
examples/pytorch/gat/gat.py  (view file @ 9314aabd)

@@ -49,7 +49,7 @@ class GAT(nn.Module):
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits
examples/pytorch/gcn/gcn.py  (view file @ 9314aabd)

@@ -35,5 +35,5 @@ class GCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
examples/pytorch/gin/gin.py  (view file @ 9314aabd)

@@ -155,7 +155,7 @@ class GIN(nn.Module):
         hidden_rep = [h]

         for layer in range(self.num_layers - 1):
-            h = self.ginlayers[layer](h, g)
+            h = self.ginlayers[layer](g, h)
             hidden_rep.append(h)

         score_over_layer = 0
examples/pytorch/graphsage/graphsage.py  (view file @ 9314aabd)

@@ -41,7 +41,7 @@ class GraphSAGE(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
examples/pytorch/model_zoo/citation_network/conf.py  (view file @ 9314aabd)

@@ -50,7 +50,7 @@ GIN_CONFIG = {
 }

 CHEBNET_CONFIG = {
-    'extra_args': [16, 1, 3, True],
+    'extra_args': [32, 1, 2, True],
     'lr': 1e-2,
     'weight_decay': 5e-4,
 }
examples/pytorch/model_zoo/citation_network/models.py  (view file @ 9314aabd)

@@ -30,7 +30,7 @@ class GCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h

@@ -70,9 +70,9 @@ class GAT(nn.Module):
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits

@@ -101,7 +101,7 @@ class GraphSAGE(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h

@@ -148,7 +148,7 @@ class APPNP(nn.Module):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h

@@ -178,7 +178,7 @@ class TAGCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h

@@ -210,7 +210,7 @@ class AGNN(nn.Module):
     def forward(self, features):
         h = self.proj(features)
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return self.cls(h)

@@ -231,7 +231,7 @@ class SGC(nn.Module):
                           bias=bias)

     def forward(self, features):
-        return self.net(features, self.g)
+        return self.net(self.g, features)


 class GIN(nn.Module):

@@ -286,7 +286,7 @@ class GIN(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h


 class ChebNet(nn.Module):

@@ -316,5 +316,5 @@ class ChebNet(nn.Module):
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h, [2])
         return h
\ No newline at end of file
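The only non-mechanical call-site change above is in ChebNet: besides moving `self.g` to the front, the layer call now passes `[2]` as an explicit `lambda_max`, the per-graph largest eigenvalue of the normalized Laplacian (2 is its upper bound). This matches the ChebConv change further down, which computes the value with `dgl.laplacian_lambda_max` when it is omitted. A hedged sketch of the two equivalent forms:

# inside ChebNet.forward, assuming `layer` is a dgl ChebConv and `h` the node features
h = layer(self.g, h)         # lambda_max omitted: computed via dgl.laplacian_lambda_max(self.g)
h = layer(self.g, h, [2])    # lambda_max given per graph: skips the eigenvalue computation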
examples/pytorch/sgc/sgc.py  (view file @ 9314aabd)

@@ -19,7 +19,7 @@ from dgl.nn.pytorch.conv import SGConv
 def evaluate(model, g, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, g)[mask]  # only compute the evaluation set
+        logits = model(g, features)[mask]  # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)

@@ -86,7 +86,7 @@ def main(args):
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features, g)  # only compute the train set
+        logits = model(g, features)  # only compute the train set
         loss = loss_fcn(logits[train_mask], labels[train_mask])

         optimizer.zero_grad()
examples/pytorch/sgc/sgc_reddit.py  (view file @ 9314aabd)

@@ -21,7 +21,7 @@ def normalize(h):
 def evaluate(model, features, graph, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, graph)[mask]  # only compute the evaluation set
+        logits = model(graph, features)[mask]  # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)

@@ -82,7 +82,7 @@ def main(args):
     # define loss closure
     def closure():
         optimizer.zero_grad()
-        output = model(features, g)[train_mask]
+        output = model(g, features)[train_mask]
         loss_train = F.cross_entropy(output, labels[train_mask])
         loss_train.backward()
         return loss_train
examples/pytorch/tagcn/tagcn.py  (view file @ 9314aabd)

@@ -35,5 +35,5 @@ class TAGCN(nn.Module):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
python/dgl/model_zoo/chem/mpnn.py  (view file @ 9314aabd)

@@ -145,7 +145,7 @@ class MPNNModel(nn.Module):
             out, h = self.gru(m.unsqueeze(0), h)
             out = out.squeeze(0)

-        out = self.set2set(out, g)
+        out = self.set2set(g, out)
         out = F.relu(self.lin1(out))
         out = self.lin2(out)
         return out
python/dgl/nn/mxnet/conv.py  (view file @ 9314aabd)

@@ -83,7 +83,7 @@ class GraphConv(gluon.Block):
         self._activation = activation

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph convolution.

         Notes

@@ -95,10 +95,10 @@ class GraphConv(gluon.Block):
         Parameters
         ----------
-        feat : mxnet.NDArray
-            The input feature
         graph : DGLGraph
             The graph.
+        feat : mxnet.NDArray
+            The input feature

         Returns
         -------
python/dgl/nn/mxnet/glob.py  (view file @ 9314aabd)

@@ -19,16 +19,16 @@ class SumPooling(nn.Block):
     def __init__(self):
         super(SumPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sum pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -56,16 +56,16 @@ class AvgPooling(nn.Block):
     def __init__(self):
         super(AvgPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute average pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -93,16 +93,16 @@ class MaxPooling(nn.Block):
     def __init__(self):
         super(MaxPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute max pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -134,16 +134,16 @@ class SortPooling(nn.Block):
         super(SortPooling, self).__init__()
         self.k = k

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sort pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -190,16 +190,16 @@ class GlobalAttentionPooling(nn.Block):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute global attention pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -258,16 +258,16 @@ class Set2Set(nn.Block):
         self.lstm = gluon.rnn.LSTM(
             self.input_dim, num_layers=n_layers, input_size=self.output_dim)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute set2set pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
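All six MXNet readout blocks keep their behaviour; only the argument order flips, and on a BatchedDGLGraph they still return one row per graph. A short Gluon-flavoured sketch of the new call, loosely following tests/mxnet/test_nn.py (graph construction and sizes are illustrative):

import networkx as nx
import mxnet as mx
import dgl
from dgl.nn.mxnet.glob import SumPooling

g = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g, g])                          # batched graph with 3 copies
feat = mx.nd.random.randn(bg.number_of_nodes(), 5)

sum_pool = SumPooling()
# before: readout = sum_pool(feat, bg)
readout = sum_pool(bg, feat)                       # after: shape (3, 5), one row per graph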
python/dgl/nn/pytorch/conv.py  (view file @ 9314aabd)

@@ -107,7 +107,7 @@ class GraphConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph convolution.

         Notes

@@ -119,10 +119,10 @@ class GraphConv(nn.Module):
         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature
         graph : DGLGraph
             The graph.
+        feat : torch.Tensor
+            The input feature

         Returns
         -------

@@ -246,16 +246,16 @@ class GATConv(nn.Module):
         if isinstance(self.res_fc, nn.Linear):
             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph attention network layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -338,16 +338,16 @@ class TAGConv(nn.Module):
         gain = nn.init.calculate_gain('relu')
         nn.init.xavier_normal_(self.lin.weight, gain=gain)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute topology adaptive graph convolution.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -643,16 +643,16 @@ class SAGEConv(nn.Module):
         _, (rst, _) = self.lstm(m, h)
         return {'neigh': rst.squeeze(0)}

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute GraphSAGE layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -742,11 +742,13 @@ class GatedGraphConv(nn.Module):
         self.gru.reset_parameters()
         init.xavier_normal_(self.edge_embed.weight, gain=gain)

-    def forward(self, feat, etypes, graph):
+    def forward(self, graph, feat, etypes):
         """Compute Gated Graph Convolution layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the

@@ -754,8 +756,6 @@ class GatedGraphConv(nn.Module):
         etypes : torch.LongTensor
             The edge type tensor of shape :math:`(E,)` where :math:`E` is
             the number of edges of the graph.
-        graph : DGLGraph
-            The graph.

         Returns
         -------
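Modules that take extra inputs besides the node features (edge types here, pseudo-coordinates for GMMConv and edge features for NNConv below) keep those arguments after `feat`; only `graph` moves to the front. A hedged sketch for GatedGraphConv, assuming its usual (in_feats, out_feats, n_steps, n_etypes) constructor; the graph and sizes are illustrative:

import networkx as nx
import torch as th
import dgl
from dgl.nn.pytorch.conv import GatedGraphConv

g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
feat = th.randn(20, 10)
etypes = th.randint(0, 4, (g.number_of_edges(),))   # one type id per edge

ggconv = GatedGraphConv(10, 10, n_steps=2, n_etypes=4)
# before: h = ggconv(feat, etypes, g)
h = ggconv(g, feat, etypes)                          # after: graph first, extras stay last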
@@ -856,11 +856,13 @@ class GMMConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias.data)

-    def forward(self, feat, pseudo, graph):
+    def forward(self, graph, feat, pseudo):
         """Compute Gaussian Mixture Model Convolution layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the

@@ -869,8 +871,6 @@ class GMMConv(nn.Module):
             The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
             :math:`E` is the number of edges of the graph and :math:`D_{u}`
             is the dimensionality of pseudo coordinate.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -940,18 +940,18 @@ class GINConv(nn.Module):
         else:
             self.register_buffer('eps', th.FloatTensor([init_eps]))

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute Graph Isomorphism Network layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D)` where :math:`D`
             could be any positive integer, :math:`N` is the number
             of nodes. If ``apply_func`` is not None, :math:`D` should
             fit the input dimensionality requirement of ``apply_func``.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -1025,16 +1025,22 @@ class ChebConv(nn.Module):
             if module.bias is not None:
                 init.zeros_(module.bias)

-    def forward(self, feat, graph, lambda_max=None):
+    def forward(self, graph, feat, lambda_max=None):
         r"""Compute ChebNet layer.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.
         lambda_max : list or tensor or None, optional.
             A list(tensor) with length :math:`B`, stores the largest eigenvalue
             of the normalized laplacian of each individual graph in ``graph``,
             where :math:`B` is the batch size of the input graph. Default: None.
+            If None, this method would compute the list by calling
+            ``dgl.laplacian_lambda_max``.

         Returns
         -------

@@ -1047,13 +1053,13 @@ class ChebConv(nn.Module):
             graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
         if lambda_max is None:
             lambda_max = laplacian_lambda_max(graph)
         if isinstance(lambda_max, list):
             lambda_max = th.Tensor(lambda_max).to(feat.device)
         if lambda_max.dim() < 1:
             lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
         # broadcast from (B, 1) to (N, 1)
         lambda_max = broadcast_nodes(graph, lambda_max)
         # T0(X)
         Tx_0 = feat
         rst = self.fc[0](Tx_0)
         # T1(X)
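The ChebConv snippet above also documents the new `lambda_max` handling: if it is omitted the layer calls `dgl.laplacian_lambda_max(graph)` itself, and a list is converted to a tensor and broadcast to one value per node with `broadcast_nodes`, so a batched graph can carry a different eigenvalue per component. A small sketch of the call patterns this enables, assuming the usual ChebConv(in_feats, out_feats, k) constructor and illustrative sizes:

import networkx as nx
import torch as th
import dgl
from dgl.nn.pytorch.conv import ChebConv

g = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g, g])
feat = th.randn(bg.number_of_nodes(), 5)

cheb = ChebConv(5, 8, k=2)
h = cheb(bg, feat)                # lambda_max computed internally, once per graph
h = cheb(bg, feat, [2., 2., 2.])  # or supplied explicitly: one value per graph in the batch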
@@ -1125,16 +1131,16 @@ class SGConv(nn.Module):
         self._k = k
         self.norm = norm

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute Simplifying Graph Convolution layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -1241,11 +1247,13 @@ class NNConv(nn.Module):
         if isinstance(self.res_fc, nn.Linear):
             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

-    def forward(self, feat, efeat, graph):
+    def forward(self, graph, feat, efeat):
         r"""Compute MPNN Graph Convolution layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`N`
             is the number of nodes of the graph and :math:`D_{in}` is the

@@ -1253,8 +1261,6 @@ class NNConv(nn.Module):
         efeat : torch.Tensor
             The edge feature of shape :math:`(N, *)`, should fit the input
             shape requirement of ``edge_nn``.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -1309,16 +1315,16 @@ class APPNPConv(nn.Module):
         self._alpha = alpha
         self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute APPNP layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, *)` :math:`N` is the
             number of nodes, and :math:`*` could be of any shape.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -1374,16 +1380,16 @@ class AGNNConv(nn.Module):
         else:
             self.register_buffer('beta', th.Tensor([init_beta]))

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute AGNN layer.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature of shape :math:`(N, *)` :math:`N` is the
             number of nodes, and :math:`*` could be of any shape.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -1452,18 +1458,18 @@ class DenseGraphConv(nn.Module):
         if self.bias is not None:
             init.zeros_(self.bias)

-    def forward(self, feat, adj):
+    def forward(self, adj, feat):
         r"""Compute (Dense) Graph Convolution layer.

         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.

         Returns
         -------

@@ -1549,18 +1555,18 @@ class DenseSAGEConv(nn.Module):
         gain = nn.init.calculate_gain('relu')
         nn.init.xavier_uniform_(self.fc.weight, gain=gain)

-    def forward(self, feat, adj):
+    def forward(self, adj, feat):
         r"""Compute (Dense) Graph SAGE layer.

         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.

         Returns
         -------

@@ -1629,18 +1635,21 @@ class DenseChebConv(nn.Module):
         for i in range(self._k):
             init.xavier_normal_(self.W[i], init.calculate_gain('relu'))

-    def forward(self, feat, adj):
+    def forward(self, adj, feat, lambda_max=None):
         r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer.

         Parameters
         ----------
-        feat : torch.Tensor
-            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
-            is size of input feature, :math:`N` is the number of nodes.
         adj : torch.Tensor
             The adjacency matrix of the graph to apply Graph Convolution on,
             should be of shape :math:`(N, N)`, where a row represents the destination
             and a column represents the source.
+        feat : torch.Tensor
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is size of input feature, :math:`N` is the number of nodes.
+        lambda_max : float or None, optional
+            A float value indicates the largest eigenvalue of given graph.
+            Default: None.

         Returns
         -------

@@ -1656,10 +1665,11 @@ class DenseChebConv(nn.Module):
         I = th.eye(num_nodes).to(A)
         L = I - D_invsqrt @ A @ D_invsqrt

+        if lambda_max is None:
             lambda_ = th.eig(L)[0][:, 0]
             lambda_max = lambda_.max()
         L_hat = 2 * L / lambda_max - I
         Z = [th.eye(num_nodes).to(A)]
         for i in range(1, self._k):
             if i == 1:
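The dense variants follow the same rule with the adjacency matrix taking the graph's place as the first argument, and DenseChebConv additionally gains an optional `lambda_max`. A sketch of the paired call pattern, loosely following test_dense_cheb_conv at the bottom of this commit; the (in_feats, out_feats, k) constructor arguments and graph sizes are illustrative:

import networkx as nx
import torch as th
import dgl
from dgl.nn.pytorch.conv import ChebConv, DenseChebConv

g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
adj = g.adjacency_matrix().to_dense()    # dense (N, N) adjacency for the Dense* layers
feat = th.randn(20, 5)

cheb = ChebConv(5, 2, k=3)
dense_cheb = DenseChebConv(5, 2, k=3)

out_cheb = cheb(g, feat, [2.0])              # graph first, lambda_max as a per-graph list
out_dense_cheb = dense_cheb(adj, feat, 2.0)  # adjacency first, lambda_max as a scalar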
python/dgl/nn/pytorch/glob.py  (view file @ 9314aabd)

@@ -23,17 +23,17 @@ class SumPooling(nn.Module):
     def __init__(self):
         super(SumPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sum pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -57,16 +57,16 @@ class AvgPooling(nn.Module):
     def __init__(self):
         super(AvgPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute average pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -90,16 +90,16 @@ class MaxPooling(nn.Module):
     def __init__(self):
         super(MaxPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute max pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -127,16 +127,16 @@ class SortPooling(nn.Module):
         super(SortPooling, self).__init__()
         self.k = k

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sort pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -179,16 +179,16 @@ class GlobalAttentionPooling(nn.Module):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute global attention pooling.

         Parameters
         ----------
+        graph : DGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph
-            The graph.

         Returns
         -------

@@ -252,16 +252,16 @@ class Set2Set(nn.Module):
         """Reinitialize learnable parameters."""
         self.lstm.reset_parameters()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute set2set pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -568,17 +568,17 @@ class SetTransformerEncoder(nn.Module):
         self.layers = nn.ModuleList(layers)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         """
         Compute the Encoder part of Set Transformer.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------

@@ -634,17 +634,17 @@ class SetTransformerDecoder(nn.Module):
         self.layers = nn.ModuleList(layers)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         """
         Compute the decoder part of Set Transformer.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : torch.Tensor
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
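The PyTorch readouts mirror the MXNet ones: on a BatchedDGLGraph they still return one row per graph, just with `graph` first. A small sketch with AvgPooling and Set2Set, whose constructor takes (input_dim, n_iters, n_layers); sizes are illustrative and follow the unit tests:

import networkx as nx
import torch as th
import dgl
from dgl.nn.pytorch.glob import AvgPooling, Set2Set

g = dgl.DGLGraph(nx.path_graph(10))
bg = dgl.batch([g, g, g])
feat = th.randn(bg.number_of_nodes(), 5)

avg_pool = AvgPooling()
s2s = Set2Set(5, n_iters=2, n_layers=1)

# before: avg_pool(feat, bg), s2s(feat, bg)
mean_readout = avg_pool(bg, feat)   # shape (3, 5)
s2s_readout = s2s(bg, feat)         # shape (3, 10): Set2Set doubles the feature dim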
tests/mxnet/test_nn.py  (view file @ 9314aabd)

@@ -24,13 +24,13 @@ def test_graph_conv():
     conv.initialize(ctx=ctx)
     # test#1: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
     h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

@@ -40,12 +40,12 @@ def test_graph_conv():
     # test#3: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
     h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0

@@ -55,18 +55,18 @@ def test_graph_conv():
     with autograd.train_mode():
         # test#3: basic
         h0 = F.ones((3, 5))
-        h1 = conv(h0, g)
+        h1 = conv(g, h0)
         assert len(g.ndata) == 0
         assert len(g.edata) == 0
         # test#4: basic
         h0 = F.ones((3, 5, 5))
-        h1 = conv(h0, g)
+        h1 = conv(g, h0)
         assert len(g.ndata) == 0
         assert len(g.edata) == 0

     # test not override features
     g.ndata["h"] = 2 * F.ones((3, 1))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 1
     assert len(g.edata) == 0
     assert "h" in g.ndata

@@ -82,13 +82,13 @@ def test_set2set():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = s2s(h0, g)
+    h1 = s2s(g, h0)
     assert h1.shape[0] == 10 and h1.ndim == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g])
     h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = s2s(h0, bg)
+    h1 = s2s(bg, h0)
     assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2

 def test_glob_att_pool():

@@ -100,13 +100,13 @@ def test_glob_att_pool():
     print(gap)
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = gap(h0, g)
+    h1 = gap(g, h0)
     assert h1.shape[0] == 10 and h1.ndim == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
     h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = gap(h0, bg)
+    h1 = gap(bg, h0)
     assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2

 def test_simple_pool():

@@ -120,20 +120,20 @@ def test_simple_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = sum_pool(h0, g)
+    h1 = sum_pool(g, h0)
     check_close(h1, F.sum(h0, 0))
-    h1 = avg_pool(h0, g)
+    h1 = avg_pool(g, h0)
     check_close(h1, F.mean(h0, 0))
-    h1 = max_pool(h0, g)
+    h1 = max_pool(g, h0)
     check_close(h1, F.max(h0, 0))
-    h1 = sort_pool(h0, g)
+    h1 = sort_pool(g, h0)
     assert h1.shape[0] == 10 * 5 and h1.ndim == 1

     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
     bg = dgl.batch([g, g_, g, g_, g])
     h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = sum_pool(h0, bg)
+    h1 = sum_pool(bg, h0)
     truth = mx.nd.stack(F.sum(h0[:15], 0),
                         F.sum(h0[15:20], 0),
                         F.sum(h0[20:35], 0),

@@ -141,7 +141,7 @@ def test_simple_pool():
                         F.sum(h0[40:55], 0), axis=0)
     check_close(h1, truth)
-    h1 = avg_pool(h0, bg)
+    h1 = avg_pool(bg, h0)
     truth = mx.nd.stack(F.mean(h0[:15], 0),
                         F.mean(h0[15:20], 0),
                         F.mean(h0[20:35], 0),

@@ -149,7 +149,7 @@ def test_simple_pool():
                         F.mean(h0[40:55], 0), axis=0)
     check_close(h1, truth)
-    h1 = max_pool(h0, bg)
+    h1 = max_pool(bg, h0)
     truth = mx.nd.stack(F.max(h0[:15], 0),
                         F.max(h0[15:20], 0),
                         F.max(h0[20:35], 0),

@@ -157,7 +157,7 @@ def test_simple_pool():
                         F.max(h0[40:55], 0), axis=0)
     check_close(h1, truth)
-    h1 = sort_pool(h0, bg)
+    h1 = sort_pool(bg, h0)
     assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

 def uniform_attention(g, shape):
tests/pytorch/test_nn.py  (view file @ 9314aabd)

@@ -24,13 +24,13 @@ def test_graph_conv():
     print(conv)
     # test#1: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
     h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

@@ -40,12 +40,12 @@ def test_graph_conv():
         conv = conv.to(ctx)
     # test#3: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
     h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0

@@ -54,12 +54,12 @@ def test_graph_conv():
         conv = conv.to(ctx)
     # test#3: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
     h0 = F.ones((3, 5, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0

@@ -94,7 +94,7 @@ def test_tagconv():
     # test#1: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     shp = norm.shape + (1,) * (h0.dim() - 1)

@@ -107,7 +107,7 @@ def test_tagconv():
         conv = conv.to(ctx)
     # test#2: basic
     h0 = F.ones((3, 5))
-    h1 = conv(h0, g)
+    h1 = conv(g, h0)
     assert h1.shape[-1] == 2

     # test reset_parameters

@@ -127,7 +127,7 @@ def test_set2set():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = s2s(h0, g)
+    h1 = s2s(g, h0)
     assert h1.shape[0] == 10 and h1.dim() == 1

     # test#2: batched graph

@@ -135,7 +135,7 @@ def test_set2set():
     g2 = dgl.DGLGraph(nx.path_graph(5))
     bg = dgl.batch([g, g1, g2])
     h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = s2s(h0, bg)
+    h1 = s2s(bg, h0)
     assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

 def test_glob_att_pool():

@@ -149,13 +149,13 @@ def test_glob_att_pool():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
-    h1 = gap(h0, g)
+    h1 = gap(g, h0)
     assert h1.shape[0] == 10 and h1.dim() == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
     h0 = F.randn((bg.number_of_nodes(), 5))
-    h1 = gap(h0, bg)
+    h1 = gap(bg, h0)
     assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

 def test_simple_pool():

@@ -176,13 +176,13 @@ def test_simple_pool():
         max_pool = max_pool.to(ctx)
         sort_pool = sort_pool.to(ctx)
         h0 = h0.to(ctx)
-    h1 = sum_pool(h0, g)
+    h1 = sum_pool(g, h0)
     assert F.allclose(h1, F.sum(h0, 0))
-    h1 = avg_pool(h0, g)
+    h1 = avg_pool(g, h0)
     assert F.allclose(h1, F.mean(h0, 0))
-    h1 = max_pool(h0, g)
+    h1 = max_pool(g, h0)
     assert F.allclose(h1, F.max(h0, 0))
-    h1 = sort_pool(h0, g)
+    h1 = sort_pool(g, h0)
     assert h1.shape[0] == 10 * 5 and h1.dim() == 1

     # test#2: batched graph

@@ -192,7 +192,7 @@ def test_simple_pool():
     if F.gpu_ctx():
         h0 = h0.to(ctx)
-    h1 = sum_pool(h0, bg)
+    h1 = sum_pool(bg, h0)
     truth = th.stack([F.sum(h0[:15], 0),
                       F.sum(h0[15:20], 0),
                       F.sum(h0[20:35], 0),

@@ -200,7 +200,7 @@ def test_simple_pool():
                       F.sum(h0[40:55], 0)], 0)
     assert F.allclose(h1, truth)
-    h1 = avg_pool(h0, bg)
+    h1 = avg_pool(bg, h0)
     truth = th.stack([F.mean(h0[:15], 0),
                       F.mean(h0[15:20], 0),
                       F.mean(h0[20:35], 0),

@@ -208,7 +208,7 @@ def test_simple_pool():
                       F.mean(h0[40:55], 0)], 0)
     assert F.allclose(h1, truth)
-    h1 = max_pool(h0, bg)
+    h1 = max_pool(bg, h0)
     truth = th.stack([F.max(h0[:15], 0),
                       F.max(h0[15:20], 0),
                       F.max(h0[20:35], 0),

@@ -216,7 +216,7 @@ def test_simple_pool():
                       F.max(h0[40:55], 0)], 0)
     assert F.allclose(h1, truth)
-    h1 = sort_pool(h0, bg)
+    h1 = sort_pool(bg, h0)
     assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

 def test_set_trans():

@@ -234,11 +234,11 @@ def test_set_trans():
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 50))
-    h1 = st_enc_0(h0, g)
+    h1 = st_enc_0(g, h0)
     assert h1.shape == h0.shape
-    h1 = st_enc_1(h0, g)
+    h1 = st_enc_1(g, h0)
     assert h1.shape == h0.shape
-    h2 = st_dec(h1, g)
+    h2 = st_dec(g, h1)
     assert h2.shape[0] == 200 and h2.dim() == 1

     # test#2: batched graph

@@ -246,12 +246,12 @@ def test_set_trans():
     g2 = dgl.DGLGraph(nx.path_graph(10))
     bg = dgl.batch([g, g1, g2])
     h0 = F.randn((bg.number_of_nodes(), 50))
-    h1 = st_enc_0(h0, bg)
+    h1 = st_enc_0(bg, h0)
     assert h1.shape == h0.shape
-    h1 = st_enc_1(h0, bg)
+    h1 = st_enc_1(bg, h0)
     assert h1.shape == h0.shape
-    h2 = st_dec(h1, bg)
+    h2 = st_dec(bg, h1)
     assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

 def uniform_attention(g, shape):

@@ -375,7 +375,7 @@ def test_gat_conv():
         gat = gat.to(ctx)
         feat = feat.to(ctx)
-    h = gat(feat, g)
+    h = gat(g, feat)
     assert h.shape[-1] == 2 and h.shape[-2] == 4

 def test_sage_conv():

@@ -389,7 +389,7 @@ def test_sage_conv():
         sage = sage.to(ctx)
         feat = feat.to(ctx)
-    h = sage(feat, g)
+    h = sage(g, feat)
     assert h.shape[-1] == 10

 def test_sgc_conv():

@@ -403,7 +403,7 @@ def test_sgc_conv():
         sgc = sgc.to(ctx)
         feat = feat.to(ctx)
-    h = sgc(feat, g)
+    h = sgc(g, feat)
     assert h.shape[-1] == 10

     # cached

@@ -412,8 +412,8 @@ def test_sgc_conv():
     if F.gpu_ctx():
         sgc = sgc.to(ctx)
-    h_0 = sgc(feat, g)
-    h_1 = sgc(feat + 1, g)
+    h_0 = sgc(g, feat)
+    h_1 = sgc(g, feat + 1)
     assert F.allclose(h_0, h_1)
     assert h_0.shape[-1] == 10

@@ -427,7 +427,7 @@ def test_appnp_conv():
         appnp = appnp.to(ctx)
         feat = feat.to(ctx)
-    h = appnp(feat, g)
+    h = appnp(g, feat)
     assert h.shape[-1] == 5

 def test_gin_conv():

@@ -444,7 +444,7 @@ def test_gin_conv():
         gin = gin.to(ctx)
         feat = feat.to(ctx)
-    h = gin(feat, g)
+    h = gin(g, feat)
     assert h.shape[-1] == 12

 def test_agnn_conv():

@@ -457,7 +457,7 @@ def test_agnn_conv():
         agnn = agnn.to(ctx)
         feat = feat.to(ctx)
-    h = agnn(feat, g)
+    h = agnn(g, feat)
     assert h.shape[-1] == 5

 def test_gated_graph_conv():

@@ -472,7 +472,7 @@ def test_gated_graph_conv():
         feat = feat.to(ctx)
         etypes = etypes.to(ctx)
-    h = ggconv(feat, etypes, g)
+    h = ggconv(g, feat, etypes)
     # current we only do shape check
     assert h.shape[-1] == 10

@@ -489,7 +489,7 @@ def test_nn_conv():
         feat = feat.to(ctx)
         efeat = efeat.to(ctx)
-    h = nnconv(feat, efeat, g)
+    h = nnconv(g, feat, efeat)
     # currently we only do shape check
     assert h.shape[-1] == 10

@@ -505,7 +505,7 @@ def test_gmm_conv():
         feat = feat.to(ctx)
         pseudo = pseudo.to(ctx)
-    h = gmmconv(feat, pseudo, g)
+    h = gmmconv(g, feat, pseudo)
     # currently we only do shape check
     assert h.shape[-1] == 10

@@ -523,8 +523,8 @@ def test_dense_graph_conv():
         dense_conv = dense_conv.to(ctx)
         feat = feat.to(ctx)
-    out_conv = conv(feat, g)
-    out_dense_conv = dense_conv(feat, adj)
+    out_conv = conv(g, feat)
+    out_dense_conv = dense_conv(adj, feat)
     assert F.allclose(out_conv, out_dense_conv)

 def test_dense_sage_conv():

@@ -541,8 +541,8 @@ def test_dense_sage_conv():
         dense_sage = dense_sage.to(ctx)
         feat = feat.to(ctx)
-    out_sage = sage(feat, g)
-    out_dense_sage = dense_sage(feat, adj)
+    out_sage = sage(g, feat)
+    out_dense_sage = dense_sage(adj, feat)
     assert F.allclose(out_sage, out_dense_sage)

 def test_dense_cheb_conv():

@@ -562,8 +562,8 @@ def test_dense_cheb_conv():
         dense_cheb = dense_cheb.to(ctx)
         feat = feat.to(ctx)
-    out_cheb = cheb(feat, g)
-    out_dense_cheb = dense_cheb(feat, adj)
+    out_cheb = cheb(g, feat, [2.0])
+    out_dense_cheb = dense_cheb(adj, feat, 2.0)
     assert F.allclose(out_cheb, out_dense_cheb)

 if __name__ == '__main__':