Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4af02022
Unverified
Commit
4af02022
authored
Mar 19, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Mar 19, 2020
Browse files
[Bug] Multiple fixes (#1374)
* multiple fixes * lint * lint x2
parent
0a51dc54
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
195 additions
and
71 deletions
+195
-71
python/dgl/graph.py
python/dgl/graph.py
+24
-0
python/dgl/heterograph.py
python/dgl/heterograph.py
+20
-8
python/dgl/nn/mxnet/conv/sageconv.py
python/dgl/nn/mxnet/conv/sageconv.py
+40
-15
python/dgl/nn/pytorch/conv/sageconv.py
python/dgl/nn/pytorch/conv/sageconv.py
+4
-6
python/dgl/nn/pytorch/softmax.py
python/dgl/nn/pytorch/softmax.py
+1
-1
python/dgl/nn/tensorflow/conv/sageconv.py
python/dgl/nn/tensorflow/conv/sageconv.py
+39
-14
tests/mxnet/test_nn.py
tests/mxnet/test_nn.py
+23
-10
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+29
-17
tests/tensorflow/test_nn.py
tests/tensorflow/test_nn.py
+15
-0
No files found.
python/dgl/graph.py
View file @
4af02022
...
@@ -51,6 +51,30 @@ class DGLBaseGraph(object):
...
@@ -51,6 +51,30 @@ class DGLBaseGraph(object):
"""
"""
return
self
.
_graph
.
number_of_nodes
()
return
self
.
_graph
.
number_of_nodes
()
def
number_of_src_nodes
(
self
):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return
self
.
_graph
.
number_of_nodes
()
def
number_of_dst_nodes
(
self
):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return
self
.
_graph
.
number_of_nodes
()
def
__len__
(
self
):
def
__len__
(
self
):
"""Return the number of nodes in the graph."""
"""Return the number of nodes in the graph."""
return
self
.
number_of_nodes
()
return
self
.
number_of_nodes
()
...
...
python/dgl/heterograph.py
View file @
4af02022
...
@@ -716,8 +716,12 @@ class DGLHeteroGraph(object):
...
@@ -716,8 +716,12 @@ class DGLHeteroGraph(object):
def
srcdata
(
self
):
def
srcdata
(
self
):
"""Return the data view of all nodes in the SRC category.
"""Return the data view of all nodes in the SRC category.
**Only works if the graph is uni-bipartite and has one node type in the
Only works if the graph is either
SRC category.**
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
Examples
--------
--------
...
@@ -750,8 +754,10 @@ class DGLHeteroGraph(object):
...
@@ -750,8 +754,10 @@ class DGLHeteroGraph(object):
--------
--------
nodes
nodes
"""
"""
assert
self
.
is_unibipartite
,
'srcdata is only allowed for uni-bipartite graph.'
err_msg
=
(
assert
len
(
self
.
srctypes
)
==
1
,
'srcdata is only allowed when there is only one SRC type.'
'srcdata is only allowed when there is only one %s type.'
%
(
'SRC'
if
self
.
is_unibipartite
else
'node'
))
assert
len
(
self
.
srctypes
)
==
1
,
err_msg
ntype
=
self
.
srctypes
[
0
]
ntype
=
self
.
srctypes
[
0
]
ntid
=
self
.
get_ntype_id_from_src
(
ntype
)
ntid
=
self
.
get_ntype_id_from_src
(
ntype
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
...
@@ -760,8 +766,12 @@ class DGLHeteroGraph(object):
...
@@ -760,8 +766,12 @@ class DGLHeteroGraph(object):
def
dstdata
(
self
):
def
dstdata
(
self
):
"""Return the data view of all destination nodes.
"""Return the data view of all destination nodes.
**Only works if the graph is uni-bipartite and has one node type in the
Only works if the graph is either
DST category.**
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
Examples
--------
--------
...
@@ -794,8 +804,10 @@ class DGLHeteroGraph(object):
...
@@ -794,8 +804,10 @@ class DGLHeteroGraph(object):
--------
--------
nodes
nodes
"""
"""
assert
self
.
is_unibipartite
,
'dstdata is only allowed for uni-bipartite graph.'
err_msg
=
(
assert
len
(
self
.
dsttypes
)
==
1
,
'dstdata is only allowed when there is only one DST type.'
'dstdata is only allowed when there is only one %s type.'
%
(
'DST'
if
self
.
is_unibipartite
else
'node'
))
assert
len
(
self
.
dsttypes
)
==
1
,
err_msg
ntype
=
self
.
dsttypes
[
0
]
ntype
=
self
.
dsttypes
[
0
]
ntid
=
self
.
get_ntype_id_from_dst
(
ntype
)
ntid
=
self
.
get_ntype_id_from_dst
(
ntype
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
return
HeteroNodeDataView
(
self
,
ntype
,
ntid
,
ALL
)
...
...
python/dgl/nn/mxnet/conv/sageconv.py
View file @
4af02022
"""MXNet Module for GraphSAGE layer"""
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable= no-member, arguments-differ, invalid-name
import
math
import
math
from
numbers
import
Integral
import
mxnet
as
mx
import
mxnet
as
mx
from
mxnet
import
nd
from
mxnet
import
nd
from
mxnet.gluon
import
nn
from
mxnet.gluon
import
nn
...
@@ -24,6 +25,14 @@ class SAGEConv(nn.Block):
...
@@ -24,6 +25,14 @@ class SAGEConv(nn.Block):
----------
----------
in_feats : int
in_feats : int
Input feature size.
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
out_feats : int
Output feature size.
Output feature size.
feat_drop : float
feat_drop : float
...
@@ -47,7 +56,15 @@ class SAGEConv(nn.Block):
...
@@ -47,7 +56,15 @@ class SAGEConv(nn.Block):
norm
=
None
,
norm
=
None
,
activation
=
None
):
activation
=
None
):
super
(
SAGEConv
,
self
).
__init__
()
super
(
SAGEConv
,
self
).
__init__
()
self
.
_in_feats
=
in_feats
if
isinstance
(
in_feats
,
tuple
):
self
.
_in_src_feats
=
in_feats
[
0
]
self
.
_in_dst_feats
=
in_feats
[
1
]
elif
isinstance
(
in_feats
,
Integral
):
self
.
_in_src_feats
=
self
.
_in_dst_feats
=
in_feats
else
:
raise
TypeError
(
'in_feats must be either int or pair of ints'
)
self
.
_out_feats
=
out_feats
self
.
_out_feats
=
out_feats
self
.
_aggre_type
=
aggregator_type
self
.
_aggre_type
=
aggregator_type
with
self
.
name_scope
():
with
self
.
name_scope
():
...
@@ -55,18 +72,18 @@ class SAGEConv(nn.Block):
...
@@ -55,18 +72,18 @@ class SAGEConv(nn.Block):
self
.
feat_drop
=
nn
.
Dropout
(
feat_drop
)
self
.
feat_drop
=
nn
.
Dropout
(
feat_drop
)
self
.
activation
=
activation
self
.
activation
=
activation
if
aggregator_type
==
'pool'
:
if
aggregator_type
==
'pool'
:
self
.
fc_pool
=
nn
.
Dense
(
in
_feats
,
use_bias
=
bias
,
self
.
fc_pool
=
nn
.
Dense
(
self
.
_in_src
_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_src
_feats
)
if
aggregator_type
==
'lstm'
:
if
aggregator_type
==
'lstm'
:
raise
NotImplementedError
raise
NotImplementedError
if
aggregator_type
!=
'gcn'
:
if
aggregator_type
!=
'gcn'
:
self
.
fc_self
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
self
.
fc_self
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_dst
_feats
)
self
.
fc_neigh
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
self
.
fc_neigh
=
nn
.
Dense
(
out_feats
,
use_bias
=
bias
,
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
weight_initializer
=
mx
.
init
.
Xavier
(
magnitude
=
math
.
sqrt
(
2.0
)),
in_units
=
in
_feats
)
in_units
=
self
.
_in_src
_feats
)
def
forward
(
self
,
graph
,
feat
):
def
forward
(
self
,
graph
,
feat
):
r
"""Compute GraphSAGE layer.
r
"""Compute GraphSAGE layer.
...
@@ -86,23 +103,31 @@ class SAGEConv(nn.Block):
...
@@ -86,23 +103,31 @@ class SAGEConv(nn.Block):
is size of output feature.
is size of output feature.
"""
"""
graph
=
graph
.
local_var
()
graph
=
graph
.
local_var
()
feat
=
self
.
feat_drop
(
feat
)
h_self
=
feat
if
isinstance
(
feat
,
tuple
):
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
if
self
.
_aggre_type
==
'mean'
:
if
self
.
_aggre_type
==
'mean'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
elif
self
.
_aggre_type
==
'gcn'
:
graph
.
ndata
[
'h'
]
=
feat
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# saame as above if homogeneous
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in degrees
# divide in degrees
degs
=
graph
.
in_degrees
().
astype
(
feat
.
dtype
)
degs
=
graph
.
in_degrees
().
astype
(
feat
_dst
.
dtype
)
degs
=
degs
.
as_in_context
(
feat
.
context
)
degs
=
degs
.
as_in_context
(
feat
_dst
.
context
)
h_neigh
=
(
graph
.
n
data
[
'neigh'
]
+
graph
.
n
data
[
'h'
])
/
(
degs
.
expand_dims
(
-
1
)
+
1
)
h_neigh
=
(
graph
.
dst
data
[
'neigh'
]
+
graph
.
dst
data
[
'h'
])
/
(
degs
.
expand_dims
(
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
elif
self
.
_aggre_type
==
'pool'
:
graph
.
n
data
[
'h'
]
=
nd
.
relu
(
self
.
fc_pool
(
feat
))
graph
.
src
data
[
'h'
]
=
nd
.
relu
(
self
.
fc_pool
(
feat
_src
))
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
elif
self
.
_aggre_type
==
'lstm'
:
raise
NotImplementedError
raise
NotImplementedError
else
:
else
:
...
...
python/dgl/nn/pytorch/conv/sageconv.py
View file @
4af02022
"""Torch Module for GraphSAGE layer"""
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable= no-member, arguments-differ, invalid-name
from
numbers
import
Integral
from
numbers
import
Integral
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
...
@@ -124,11 +123,11 @@ class SAGEConv(nn.Module):
...
@@ -124,11 +123,11 @@ class SAGEConv(nn.Module):
"""
"""
graph
=
graph
.
local_var
()
graph
=
graph
.
local_var
()
if
torch
.
is_tensor
(
feat
):
if
isinstance
(
feat
,
tuple
):
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
else
:
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
h_self
=
feat_dst
...
@@ -141,8 +140,7 @@ class SAGEConv(nn.Module):
...
@@ -141,8 +140,7 @@ class SAGEConv(nn.Module):
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
# divide in_degrees
degs
=
graph
.
in_degrees
().
float
()
degs
=
graph
.
in_degrees
().
to
(
feat_dst
)
degs
=
degs
.
to
(
feat_dst
.
device
)
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
])
/
(
degs
.
unsqueeze
(
-
1
)
+
1
)
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
])
/
(
degs
.
unsqueeze
(
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
elif
self
.
_aggre_type
==
'pool'
:
graph
.
srcdata
[
'h'
]
=
F
.
relu
(
self
.
fc_pool
(
feat_src
))
graph
.
srcdata
[
'h'
]
=
F
.
relu
(
self
.
fc_pool
(
feat_src
))
...
...
python/dgl/nn/pytorch/softmax.py
View file @
4af02022
...
@@ -49,7 +49,7 @@ class EdgeSoftmax(th.autograd.Function):
...
@@ -49,7 +49,7 @@ class EdgeSoftmax(th.autograd.Function):
if
not
is_all
(
eids
):
if
not
is_all
(
eids
):
g
=
g
.
edge_subgraph
(
eids
.
long
())
g
=
g
.
edge_subgraph
(
eids
.
long
())
n_nodes
=
g
.
number_of_nodes
()
n_nodes
=
g
.
number_of_
dst_
nodes
()
n_edges
=
g
.
number_of_edges
()
n_edges
=
g
.
number_of_edges
()
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
...
...
python/dgl/nn/tensorflow/conv/sageconv.py
View file @
4af02022
"""Tensorflow Module for GraphSAGE layer"""
"""Tensorflow Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable= no-member, arguments-differ, invalid-name
from
numbers
import
Integral
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.keras
import
layers
from
tensorflow.keras
import
layers
...
@@ -21,8 +22,16 @@ class SAGEConv(layers.Layer):
...
@@ -21,8 +22,16 @@ class SAGEConv(layers.Layer):
Parameters
Parameters
----------
----------
in_feats : int
in_feats : int
, or pair of ints
Input feature size.
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
out_feats : int
Output feature size.
Output feature size.
feat_drop : float
feat_drop : float
...
@@ -47,7 +56,15 @@ class SAGEConv(layers.Layer):
...
@@ -47,7 +56,15 @@ class SAGEConv(layers.Layer):
norm
=
None
,
norm
=
None
,
activation
=
None
):
activation
=
None
):
super
(
SAGEConv
,
self
).
__init__
()
super
(
SAGEConv
,
self
).
__init__
()
self
.
_in_feats
=
in_feats
if
isinstance
(
in_feats
,
tuple
):
self
.
_in_src_feats
=
in_feats
[
0
]
self
.
_in_dst_feats
=
in_feats
[
1
]
elif
isinstance
(
in_feats
,
Integral
):
self
.
_in_src_feats
=
self
.
_in_dst_feats
=
in_feats
else
:
raise
TypeError
(
'in_feats must be either int or pair of ints'
)
self
.
_out_feats
=
out_feats
self
.
_out_feats
=
out_feats
self
.
_aggre_type
=
aggregator_type
self
.
_aggre_type
=
aggregator_type
self
.
norm
=
norm
self
.
norm
=
norm
...
@@ -55,9 +72,9 @@ class SAGEConv(layers.Layer):
...
@@ -55,9 +72,9 @@ class SAGEConv(layers.Layer):
self
.
activation
=
activation
self
.
activation
=
activation
# aggregator type: mean/pool/lstm/gcn
# aggregator type: mean/pool/lstm/gcn
if
aggregator_type
==
'pool'
:
if
aggregator_type
==
'pool'
:
self
.
fc_pool
=
layers
.
Dense
(
in
_feats
)
self
.
fc_pool
=
layers
.
Dense
(
self
.
_in_src
_feats
)
if
aggregator_type
==
'lstm'
:
if
aggregator_type
==
'lstm'
:
self
.
lstm
=
layers
.
LSTM
(
units
=
in
_feats
)
self
.
lstm
=
layers
.
LSTM
(
units
=
self
.
_in_src
_feats
)
if
aggregator_type
!=
'gcn'
:
if
aggregator_type
!=
'gcn'
:
self
.
fc_self
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
self
.
fc_self
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
self
.
fc_neigh
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
self
.
fc_neigh
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
...
@@ -89,27 +106,35 @@ class SAGEConv(layers.Layer):
...
@@ -89,27 +106,35 @@ class SAGEConv(layers.Layer):
is size of output feature.
is size of output feature.
"""
"""
graph
=
graph
.
local_var
()
graph
=
graph
.
local_var
()
feat
=
self
.
feat_drop
(
feat
)
h_self
=
feat
if
isinstance
(
feat
,
tuple
):
feat_src
=
self
.
feat_drop
(
feat
[
0
])
feat_dst
=
self
.
feat_drop
(
feat
[
1
])
else
:
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
h_self
=
feat_dst
if
self
.
_aggre_type
==
'mean'
:
if
self
.
_aggre_type
==
'mean'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
elif
self
.
_aggre_type
==
'gcn'
:
graph
.
ndata
[
'h'
]
=
feat
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
# divide in_degrees
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
h_neigh
=
(
graph
.
n
data
[
'neigh'
]
+
graph
.
n
data
[
'h'
]
h_neigh
=
(
graph
.
dst
data
[
'neigh'
]
+
graph
.
dst
data
[
'h'
]
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
elif
self
.
_aggre_type
==
'pool'
:
graph
.
n
data
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat
))
graph
.
src
data
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat
_src
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
elif
self
.
_aggre_type
==
'lstm'
:
graph
.
n
data
[
'h'
]
=
feat
graph
.
src
data
[
'h'
]
=
feat
_src
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
graph
.
update_all
(
fn
.
copy_src
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
h_neigh
=
graph
.
n
data
[
'neigh'
]
h_neigh
=
graph
.
dst
data
[
'neigh'
]
else
:
else
:
raise
KeyError
(
raise
KeyError
(
'Aggregator type {} not recognized.'
.
format
(
self
.
_aggre_type
))
'Aggregator type {} not recognized.'
.
format
(
self
.
_aggre_type
))
...
...
tests/mxnet/test_nn.py
View file @
4af02022
...
@@ -127,17 +127,30 @@ def test_gat_conv():
...
@@ -127,17 +127,30 @@ def test_gat_conv():
assert
h1
.
shape
==
(
20
,
5
,
20
)
assert
h1
.
shape
==
(
20
,
5
,
20
)
def
test_sage_conv
():
def
test_sage_conv
():
g
=
dgl
.
DGLGraph
(
nx
.
erdos_renyi_graph
(
20
,
0.3
))
for
aggre_type
in
[
'mean'
,
'pool'
,
'gcn'
]:
ctx
=
F
.
ctx
()
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
graphsage
=
nn
.
SAGEConv
(
10
,
20
)
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
graphsage
.
initialize
(
ctx
=
ctx
)
feat
=
F
.
randn
((
100
,
5
))
print
(
graphsage
)
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
# test#1: basic
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
h0
=
F
.
randn
((
20
,
10
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
h1
=
graphsage
(
g
,
h0
)
feat
=
F
.
randn
((
100
,
5
))
assert
h1
.
shape
==
(
20
,
20
)
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
feat
=
(
F
.
randn
((
100
,
10
)),
F
.
randn
((
200
,
dst_dim
)))
sage
.
initialize
(
ctx
=
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
2
assert
h
.
shape
[
0
]
==
200
def
test_gg_conv
():
def
test_gg_conv
():
g
=
dgl
.
DGLGraph
(
nx
.
erdos_renyi_graph
(
20
,
0.3
))
g
=
dgl
.
DGLGraph
(
nx
.
erdos_renyi_graph
(
20
,
0.3
))
...
...
tests/pytorch/test_nn.py
View file @
4af02022
...
@@ -290,23 +290,28 @@ def test_edge_softmax():
...
@@ -290,23 +290,28 @@ def test_edge_softmax():
print
(
score
.
grad
[:
10
],
grad_score
[:
10
])
print
(
score
.
grad
[:
10
],
grad_score
[:
10
])
# Test 2
# Test 2
def
generate_rand_graph
(
n
):
def
generate_rand_graph
(
n
,
m
=
None
,
ctor
=
dgl
.
DGLGraph
):
arr
=
(
sp
.
sparse
.
random
(
n
,
n
,
density
=
0.1
,
format
=
'coo'
)
!=
0
).
astype
(
np
.
int64
)
if
m
is
None
:
return
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
m
=
n
arr
=
(
sp
.
sparse
.
random
(
m
,
n
,
density
=
0.1
,
format
=
'coo'
)
!=
0
).
astype
(
np
.
int64
)
g
=
generate_rand_graph
(
50
)
return
ctor
(
arr
,
readonly
=
True
)
a1
=
F
.
randn
((
g
.
number_of_edges
(),
1
)).
requires_grad_
()
a2
=
a1
.
clone
().
detach
().
requires_grad_
()
for
g
in
[
generate_rand_graph
(
50
),
g
.
edata
[
's'
]
=
a1
generate_rand_graph
(
50
,
ctor
=
dgl
.
graph
),
g
.
group_apply_edges
(
'dst'
,
lambda
edges
:
{
'ss'
:
F
.
softmax
(
edges
.
data
[
's'
],
1
)})
generate_rand_graph
(
100
,
50
,
ctor
=
dgl
.
bipartite
)]:
g
.
edata
[
'ss'
].
sum
().
backward
()
a1
=
F
.
randn
((
g
.
number_of_edges
(),
1
)).
requires_grad_
()
a2
=
a1
.
clone
().
detach
().
requires_grad_
()
builtin_sm
=
nn
.
edge_softmax
(
g
,
a2
)
g
.
edata
[
's'
]
=
a1
builtin_sm
.
sum
().
backward
()
g
.
group_apply_edges
(
'dst'
,
lambda
edges
:
{
'ss'
:
F
.
softmax
(
edges
.
data
[
's'
],
1
)})
print
(
a1
.
grad
-
a2
.
grad
)
g
.
edata
[
'ss'
].
sum
().
backward
()
assert
len
(
g
.
ndata
)
==
0
assert
len
(
g
.
edata
)
==
2
builtin_sm
=
nn
.
edge_softmax
(
g
,
a2
)
assert
F
.
allclose
(
a1
.
grad
,
a2
.
grad
,
rtol
=
1e-4
,
atol
=
1e-4
)
# Follow tolerance in unittest backend
builtin_sm
.
sum
().
backward
()
print
(
a1
.
grad
-
a2
.
grad
)
assert
len
(
g
.
srcdata
)
==
0
assert
len
(
g
.
dstdata
)
==
0
assert
len
(
g
.
edata
)
==
2
assert
F
.
allclose
(
a1
.
grad
,
a2
.
grad
,
rtol
=
1e-4
,
atol
=
1e-4
)
# Follow tolerance in unittest backend
def
test_partial_edge_softmax
():
def
test_partial_edge_softmax
():
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
...
@@ -402,6 +407,13 @@ def test_sage_conv():
...
@@ -402,6 +407,13 @@ def test_sage_conv():
h
=
sage
(
g
,
feat
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
sage
=
sage
.
to
(
ctx
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
...
...
tests/tensorflow/test_nn.py
View file @
4af02022
...
@@ -309,12 +309,27 @@ def test_gat_conv():
...
@@ -309,12 +309,27 @@ def test_gat_conv():
def
test_sage_conv
():
def
test_sage_conv
():
for
aggre_type
in
[
'mean'
,
'pool'
,
'gcn'
,
'lstm'
]:
for
aggre_type
in
[
'mean'
,
'pool'
,
'gcn'
,
'lstm'
]:
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
feat
=
F
.
randn
((
100
,
5
))
h
=
sage
(
g
,
feat
)
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
graph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
))
sage
=
nn
.
SAGEConv
(
5
,
10
,
aggre_type
)
feat
=
F
.
randn
((
100
,
5
))
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
10
g
=
dgl
.
bipartite
(
sp
.
sparse
.
random
(
100
,
200
,
density
=
0.1
))
dst_dim
=
5
if
aggre_type
!=
'gcn'
else
10
sage
=
nn
.
SAGEConv
((
10
,
dst_dim
),
2
,
aggre_type
)
feat
=
(
F
.
randn
((
100
,
10
)),
F
.
randn
((
200
,
dst_dim
)))
h
=
sage
(
g
,
feat
)
assert
h
.
shape
[
-
1
]
==
2
assert
h
.
shape
[
0
]
==
200
def
test_sgc_conv
():
def
test_sgc_conv
():
ctx
=
F
.
ctx
()
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
),
readonly
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment