Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4af02022
Unverified
Commit
4af02022
authored
Mar 19, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Mar 19, 2020
Browse files
[Bug] Multiple fixes (#1374)
* multiple fixes * lint * lint x2
parent
0a51dc54
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
195 additions
and
71 deletions
+195
-71
python/dgl/graph.py
python/dgl/graph.py
+24
-0
python/dgl/heterograph.py
python/dgl/heterograph.py
+20
-8
python/dgl/nn/mxnet/conv/sageconv.py
python/dgl/nn/mxnet/conv/sageconv.py
+40
-15
python/dgl/nn/pytorch/conv/sageconv.py
python/dgl/nn/pytorch/conv/sageconv.py
+4
-6
python/dgl/nn/pytorch/softmax.py
python/dgl/nn/pytorch/softmax.py
+1
-1
python/dgl/nn/tensorflow/conv/sageconv.py
python/dgl/nn/tensorflow/conv/sageconv.py
+39
-14
tests/mxnet/test_nn.py
tests/mxnet/test_nn.py
+23
-10
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+29
-17
tests/tensorflow/test_nn.py
tests/tensorflow/test_nn.py
+15
-0
No files found.
python/dgl/graph.py
View file @
4af02022
...
...
@@ -51,6 +51,30 @@ class DGLBaseGraph(object):
"""
return
self
.
_graph
.
number_of_nodes
()
def
number_of_src_nodes
(
self
):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return
self
.
_graph
.
number_of_nodes
()
def
number_of_dst_nodes
(
self
):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return
self
.
_graph
.
number_of_nodes
()
def
__len__
(
self
):
"""Return the number of nodes in the graph."""
return
self
.
number_of_nodes
()
...
...
python/dgl/heterograph.py
View file @
4af02022
...
...
@@ -716,8 +716,12 @@ class DGLHeteroGraph(object):
def
srcdata
(
self
):
"""Return the data view of all nodes in the SRC category.
**Only works if the graph is uni-bipartite and has one node type in the
SRC category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
...
...
@@ -750,8 +754,10 @@ class DGLHeteroGraph(object):
--------
nodes
"""
assert
self
.
is_unibipartite
,
'srcdata is only allowed for uni-bipartite graph.'
assert
len
(
self
.
srctypes
)
==
1
,
'srcdata is only allowed when there is only one SRC type.'
err_msg
=
(
'srcdata is only allowed when there is only one %s type.'
%
(
'SRC'
if
self
.
is_unibipartite
else
'node'
))
assert
len
(
self
.
srctypes
)
==
1
,
err_msg
ntype
=
self
.
srctypes
[
0
]
ntid
=
self
.
get_ntype_id_from_src
(
ntype
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
...
...
@@ -760,8 +766,12 @@ class DGLHeteroGraph(object):
def
dstdata
(
self
):
"""Return the data view of all destination nodes.
**Only works if the graph is uni-bipartite and has one node type in the
DST category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
...
...
@@ -794,8 +804,10 @@ class DGLHeteroGraph(object):
--------
nodes
"""
assert
self
.
is_unibipartite
,
'dstdata is only allowed for uni-bipartite graph.'
assert
len
(
self
.
dsttypes
)
==
1
,
'dstdata is only allowed when there is only one DST type.'
err_msg
=
(
'dstdata is only allowed when there is only one %s type.'
%
(
'DST'
if
self
.
is_unibipartite
else
'node'
))
assert
len
(
self
.
dsttypes
)
==
1
,
err_msg
ntype
=
self
.
dsttypes
[
0
]
ntid
=
self
.
get_ntype_id_from_dst
(
ntype
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
...
...
python/dgl/nn/mxnet/conv/sageconv.py
View file @
4af02022
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
math
from
numbers
import
Integral
import
mxnet
as
mx
from
mxnet
import
nd
from
mxnet.gluon
import
nn
...
...
@@ -24,6 +25,14 @@ class SAGEConv(nn.Block):
----------
in_feats : int
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
...
...
@@ -47,7 +56,15 @@ class SAGEConv(nn.Block):
norm
=
None
,
activation
=
None
):
super
(
SAGEConv
,
self
).
__init__
()
self
.
_in_feats
=
in_feats
if
isinstance
(
in_feats
,
tuple
):
self
.
_in_src_feats
=
in_feats
[
0
]
self
.
_in_dst_feats
=
in_feats
[
1
]
elif
isinstance
(
in_feats
,
Integral
):
self
.
_in_src_feats
=
self
.
_in_dst_feats
=
in_feats
else
:
raise
TypeError
(
'in_feats must be either int or pair of ints'
)
self
.
_out_feats
=
out_feats
self
.
_aggre_type
=
aggregator_type
with
self
.
name_scope
():
...
...
@@ -55,18 +72,18 @@ class SAGEConv(nn.Block):
self
.
feat_drop
=
nn
.
Dropout
(
feat_drop
)
self
.
activation
=
activation
if
aggregator_type
==
'pool'
:
self
.
fc_pool
=
nn
.
Dense
(
in
_feats
,
use_bias
=
bias
,
self
.
fc_pool
=
nn
.
Dense
(
self
.
_in_src
_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_src
_feats
)
if
aggregator_type
==
'lstm'
:
raise
NotImplementedError
if
aggregator_type
!=
'gcn'
:
self
.
fc_self
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_dst
_feats
)
self
.
fc_neigh
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_src
_feats
)
def
forward
(
self
,
graph
,
feat
):
r
"""Compute GraphSAGE layer.
...
...
@@ -86,23 +103,31 @@ class SAGEConv(nn.Block):
is size of output feature.
"""
graph
=
graph
.
local_var
()
feat
=
self
.
feat_drop
(
feat
)
h_self
=
feat
if
isinstance
(
feat
,
tuple
):
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
if
self
.
_aggre_type
==
'mean'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
graph
.
ndata
[
'h'
]
=
feat
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# saame as above if homogeneous
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in degrees
degs
=
graph
.
in_degrees
().
astype
(
feat
.
dtype
)
degs
=
degs
.
as_in_context
(
feat
.
context
)
h_neigh
=
(
graph
.
n
data
[
'neigh'
]
+
graph
.
n
data
[
'h'
])
/
(
degs
.
expand_dims
(
-
1
)
+
1
)
degs
=
graph
.
in_degrees
().
astype
(
feat
_dst
.
dtype
)
degs
=
degs
.
as_in_context
(
feat
_dst
.
context
)
h_neigh
=
(
graph
.
dst
data
[
'neigh'
]
+
graph
.
dst
data
[
'h'
])
/
(
degs
.
expand_dims
(
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
graph
.
n
data
[
'h'
]
=
nd
.
relu
(
self
.
fc_pool
(
feat
))
graph
.
src
data
[
'h'
]
=
nd
.
relu
(
self
.
fc_pool
(
feat
_src
))
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
raise
NotImplementedError
else
:
...
...
python/dgl/nn/pytorch/conv/sageconv.py
View file @
4af02022
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from
numbers
import
Integral
import
torch
from
torch
import
nn
from
torch.nn
import
functional
as
F
...
...
@@ -124,11 +123,11 @@ class SAGEConv(nn.Module):
"""
graph
=
graph
.
local_var
()
if
torch
.
is_tensor
(
feat
):
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
else
:
if
isinstance
(
feat
,
tuple
):
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
...
...
@@ -141,8 +140,7 @@ class SAGEConv(nn.Module):
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
degs
=
graph
.
in_degrees
().
float
()
degs
=
degs
.
to
(
feat_dst
.
device
)
degs
=
graph
.
in_degrees
().
to
(
feat_dst
)
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
])
/
(
degs
.
unsqueeze
(
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
graph
.
srcdata
[
'h'
]
=
F
.
relu
(
self
.
fc_pool
(
feat_src
))
...
...
python/dgl/nn/pytorch/softmax.py
View file @
4af02022
...
...
@@ -49,7 +49,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
g
=
g
.
edge_subgraph
(
eids
.
long
())
n_nodes
=
g
.
number_of_nodes
()
n_nodes
=
g
.
number_of_
dst_
nodes
()
n_edges
=
g
.
number_of_edges
()
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
...
...
python/dgl/nn/tensorflow/conv/sageconv.py
View file @
4af02022
"""Tensorflow Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from
numbers
import
Integral
import
tensorflow
as
tf
from
tensorflow.keras
import
layers
...
...
@@ -21,8 +22,16 @@ class SAGEConv(layers.Layer):
Parameters
----------
in_feats : int
in_feats : int
, or pair of ints
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
...
...
@@ -47,7 +56,15 @@ class SAGEConv(layers.Layer):
norm
=
None
,
activation
=
None
):
super
(
SAGEConv
,
self
).
__init__
()
self
.
_in_feats
=
in_feats
if
isinstance
(
in_feats
,
tuple
):
self
.
_in_src_feats
=
in_feats
[
0
]
self
.
_in_dst_feats
=
in_feats
[
1
]
elif
isinstance
(
in_feats
,
Integral
):
self
.
_in_src_feats
=
self
.
_in_dst_feats
=
in_feats
else
:
raise
TypeError
(
'in_feats must be either int or pair of ints'
)
self
.
_out_feats
=
out_feats
self
.
_aggre_type
=
aggregator_type
self
.
norm
=
norm
...
...
@@ -55,9 +72,9 @@ class SAGEConv(layers.Layer):
self
.
activation
=
activation
# aggregator type: mean/pool/lstm/gcn
if
aggregator_type
==
'pool'
:
self
.
fc_pool
=
layers
.
Dense
(
in
_feats
)
self
.
fc_pool
=
layers
.
Dense
(
self
.
_in_src
_feats
)
if
aggregator_type
==
'lstm'
:
self
.
lstm
=
layers
.
LSTM
(
units
=
in
_feats
)
self
.
lstm
=
layers
.
LSTM
(
units
=
self
.
_in_src
_feats
)
if
aggregator_type
!=
'gcn'
:
self
.
fc_self
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
self
.
fc_neigh
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
...
...
@@ -89,27 +106,35 @@ class SAGEConv(layers.Layer):
is size of output feature.
"""
graph
=
graph
.
local_var
()
feat
=
self
.
feat_drop
(
feat
)
h_self
=
feat
if
isinstance
(
feat
,
tuple
):
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
if
self
.
_aggre_type
==
'mean'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
graph
.
ndata
[
'h'
]
=
feat
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
h_neigh
=
(
graph
.
n
data
[
'neigh'
]
+
graph
.
n
data
[
'h'
]
h_neigh
=
(
graph
.
dst
data
[
'neigh'
]
+
graph
.
dst
data
[
'h'
]
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
graph
.
n
data
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat
))
graph
.
src
data
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat
_src
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
else
:
raise
KeyError
(
'Aggregator type {} not recognized.'
.
format
(
self
.
_aggre_type
))
...
...
tests/mxnet/test_nn.py
View file @
4af02022
...
...
@@ -127,17 +127,30 @@ def test_gat_conv():
assert
h1
.
shape
==
(
20
,
5
,
20
)
def
test_sage_conv
():
g
=
dgl
.
DGLGraph
(
nx
.
erdos_renyi_graph
(
20
,
0.3
))
ctx
=
F
.
ctx
()
graphsage
=
nn
.
SAGEConv
(
10
,
20
)
graphsage
.
initialize
(
ctx
=
ctx
)
print
(
graphsage
)
for
aggre_type
in
[
'mean'
,
'pool'
,
'gcn'
]:
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
# test#1: basic
h0
=
F
.
randn
((
20
,
10
))
h1
=
graphsage
(
g
,
h0
)
assert
h1
.
shape
==
(
20
,
20
)
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
feat
=
(
F
.
randn
((
100
,
10
)),
F
.
randn
((
200
,
dst_dim
)))
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
2
assert
h
.
shape
[
0
]
==
200
def
test_gg_conv
():
g
=
dgl
.
DGLGraph
(
nx
.
erdos_renyi_graph
(
20
,
0.3
))
...
...
tests/pytorch/test_nn.py
View file @
4af02022
...
...
@@ -290,23 +290,28 @@ def test_edge_softmax():
print
(
score
.
grad
[:
10
],
grad_score
[:
10
])
# Test 2
def
generate_rand_graph
(
n
):
arr
=
(
sp
.
sparse
.
random
(
n
,
n
,
density
=
0.1
,
format
=
'coo'
)
!=
0
).
astype
(
np
.
int64
)
return
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
g
=
generate_rand_graph
(
50
)
a1
=
F
.
randn
((
g
.
number_of_edges
(),
1
)).
requires_grad_
()
a2
=
a1
.
clone
().
detach
().
requires_grad_
()
g
.
edata
[
's'
]
=
a1
g
.
group_apply_edges
(
'dst'
,
lambda
edges
:
{
'ss'
:
F
.
softmax
(
edges
.
data
[
's'
],
1
)})
g
.
edata
[
'ss'
].
sum
().
backward
()
builtin_sm
=
nn
.
edge_softmax
(
g
,
a2
)
builtin_sm
.
sum
().
backward
()
print
(
a1
.
grad
-
a2
.
grad
)
assert
len
(
g
.
ndata
)
==
0
assert
len
(
g
.
edata
)
==
2
assert
F
.
allclose
(
a1
.
grad
,
a2
.
grad
,
rtol
=
1e-4
,
atol
=
1e-4
)
# Follow tolerance in unittest backend
def
generate_rand_graph
(
n
,
m
=
None
,
ctor
=
dgl
.
DGLGraph
):
if
m
is
None
:
m
=
n
arr
=
(
sp
.
sparse
.
random
(
m
,
n
,
density
=
0.1
,
format
=
'coo'
)
!=
0
).
astype
(
np
.
int64
)
return
ctor
(
arr
,
readonly
=
True
)
for
g
in
[
generate_rand_graph
(
50
),
generate_rand_graph
(
50
,
ctor
=
dgl
.
graph
),
generate_rand_graph
(
100
,
50
,
ctor
=
dgl
.
bipartite
)]:
a1
=
F
.
randn
((
g
.
number_of_edges
(),
1
)).
requires_grad_
()
a2
=
a1
.
clone
().
detach
().
requires_grad_
()
g
.
edata
[
's'
]
=
a1
g
.
group_apply_edges
(
'dst'
,
lambda
edges
:
{
'ss'
:
F
.
softmax
(
edges
.
data
[
's'
],
1
)})
g
.
edata
[
'ss'
].
sum
().
backward
()
builtin_sm
=
nn
.
edge_softmax
(
g
,
a2
)
builtin_sm
.
sum
().
backward
()
print
(
a1
.
grad
-
a2
.
grad
)
assert
len
(
g
.
srcdata
)
==
0
assert
len
(
g
.
dstdata
)
==
0
assert
len
(
g
.
edata
)
==
2
assert
F
.
allclose
(
a1
.
grad
,
a2
.
grad
,
rtol
=
1e-4
,
atol
=
1e-4
)
# Follow tolerance in unittest backend
def
test_partial_edge_softmax
():
g
=
dgl
.
DGLGraph
()
...
...
@@ -402,6 +407,13 @@ def test_sage_conv():
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
sage
=
sage
.
to
(
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
...
...
tests/tensorflow/test_nn.py
View file @
4af02022
...
...
@@ -309,12 +309,27 @@ def test_gat_conv():
def
test_sage_conv
():
for
aggre_type
in
[
'mean'
,
'pool'
,
'gcn'
,
'lstm'
]:
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
feat
=
(
F
.
randn
((
100
,
10
)),
F
.
randn
((
200
,
dst_dim
)))
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
2
assert
h
.
shape
[
0
]
==
200
def
test_sgc_conv
():
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment