Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
61139302
Unverified
Commit
61139302
authored
Dec 01, 2022
by
peizhou001
Committed by
GitHub
Dec 01, 2022
Browse files
[API Deprecation] Remove candidates in DGLGraph (#4946)
parent
e088acac
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
75 additions
and
401 deletions
+75
-401
examples/pytorch/transformer/modules/act.py
examples/pytorch/transformer/modules/act.py
+1
-1
examples/pytorch/transformer/modules/models.py
examples/pytorch/transformer/modules/models.py
+2
-2
examples/pytorch/transformer/modules/viz.py
examples/pytorch/transformer/modules/viz.py
+1
-1
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+3
-3
examples/tensorflow/gcn/gcn_builtin.py
examples/tensorflow/gcn/gcn_builtin.py
+1
-1
python/dgl/_deprecate/graph.py
python/dgl/_deprecate/graph.py
+4
-4
python/dgl/batch.py
python/dgl/batch.py
+2
-11
python/dgl/function/message.py
python/dgl/function/message.py
+3
-86
python/dgl/heterograph.py
python/dgl/heterograph.py
+17
-238
python/dgl/nn/mxnet/conv/graphconv.py
python/dgl/nn/mxnet/conv/graphconv.py
+2
-2
python/dgl/nn/mxnet/conv/tagconv.py
python/dgl/nn/mxnet/conv/tagconv.py
+1
-1
python/dgl/nn/pytorch/conv/ginconv.py
python/dgl/nn/pytorch/conv/ginconv.py
+1
-1
python/dgl/nn/pytorch/conv/graphconv.py
python/dgl/nn/pytorch/conv/graphconv.py
+3
-3
python/dgl/nn/pytorch/conv/sageconv.py
python/dgl/nn/pytorch/conv/sageconv.py
+1
-1
python/dgl/nn/pytorch/sparse_emb.py
python/dgl/nn/pytorch/sparse_emb.py
+0
-13
python/dgl/nn/tensorflow/conv/graphconv.py
python/dgl/nn/tensorflow/conv/graphconv.py
+2
-2
python/dgl/nn/tensorflow/conv/sageconv.py
python/dgl/nn/tensorflow/conv/sageconv.py
+4
-4
python/dgl/optim/pytorch/sparse_optim.py
python/dgl/optim/pytorch/sparse_optim.py
+10
-10
tests/compute/test_basics.py
tests/compute/test_basics.py
+7
-7
tests/compute/test_batched_graph.py
tests/compute/test_batched_graph.py
+10
-10
No files found.
examples/pytorch/transformer/modules/act.py
View file @
61139302
...
...
@@ -117,7 +117,7 @@ class UTransformer(nn.Module):
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
[
fn
.
src
_mul_e
dge
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
dge
(
'score'
,
'score'
)],
[
fn
.
u
_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
(
'score'
,
'score'
)],
[
fn
.
sum
(
'v'
,
'wv'
),
fn
.
sum
(
'score'
,
'z'
)])
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
...
...
examples/pytorch/transformer/modules/models.py
View file @
61139302
...
...
@@ -79,8 +79,8 @@ class Transformer(nn.Module):
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'score'
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
fn
.
src
_mul_e
dge
(
'v'
,
'score'
,
'v'
),
fn
.
sum
(
'v'
,
'wv'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
dge
(
'score'
,
'score'
),
fn
.
sum
(
'score'
,
'z'
))
g
.
send_and_recv
(
eids
,
fn
.
u
_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
sum
(
'v'
,
'wv'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
'score'
,
'score'
),
fn
.
sum
(
'score'
,
'z'
))
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
...
...
examples/pytorch/transformer/modules/viz.py
View file @
61139302
...
...
@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
for
j
,
dst
in
enumerate
(
dst_nodes
.
tolist
()):
if
not
g
.
has_edge_between
(
src
,
dst
):
continue
eid
=
g
.
edge_id
(
src
,
dst
)
eid
=
g
.
edge_id
s
(
src
,
dst
)
weight
[
i
][
j
]
=
g
.
edata
[
'score'
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
=
weight
.
transpose
(
0
,
2
)
...
...
examples/pytorch/tree_lstm/train.py
View file @
61139302
...
...
@@ -131,7 +131,7 @@ def main(args):
root_ids
=
[
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
if
g
.
out_degree
s
(
i
)
==
0
]
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
...
@@ -170,7 +170,7 @@ def main(args):
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
accs
.
append
([
acc
,
len
(
batch
.
label
)])
root_ids
=
[
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
s
(
i
)
==
0
]
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
...
@@ -222,7 +222,7 @@ def main(args):
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
accs
.
append
([
acc
,
len
(
batch
.
label
)])
root_ids
=
[
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
s
(
i
)
==
0
]
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
...
examples/tensorflow/gcn/gcn_builtin.py
View file @
61139302
...
...
@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer):
h
=
self
.
dropout
(
h
)
self
.
g
.
ndata
[
'h'
]
=
tf
.
matmul
(
h
,
self
.
weight
)
self
.
g
.
ndata
[
'norm_h'
]
=
self
.
g
.
ndata
[
'h'
]
*
self
.
g
.
ndata
[
'norm'
]
self
.
g
.
update_all
(
fn
.
copy_
src
(
'norm_h'
,
'm'
),
self
.
g
.
update_all
(
fn
.
copy_
u
(
'norm_h'
,
'm'
),
fn
.
sum
(
'm'
,
'h'
))
h
=
self
.
g
.
ndata
[
'h'
]
if
self
.
bias
is
not
None
:
...
...
python/dgl/_deprecate/graph.py
View file @
61139302
...
...
@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
Use the built-in message function :func:`~dgl.function.copy_
src
` for copying
Use the built-in message function :func:`~dgl.function.copy_
u
` for copying
node features as the message.
>>> m_func = dgl.function.copy_
src
('x', 'm')
>>> m_func = dgl.function.copy_
u
('x', 'm')
>>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which
...
...
@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
Use the built-in message function :func:`~dgl.function.copy_
src
` for copying
Use the built-in message function :func:`~dgl.function.copy_
u
` for copying
node features as the message.
>>> m_func = dgl.function.copy_
src
('x', 'm')
>>> m_func = dgl.function.copy_
u
('x', 'm')
>>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which
...
...
python/dgl/batch.py
View file @
61139302
...
...
@@ -2,7 +2,7 @@
from
collections.abc
import
Mapping
from
.
import
backend
as
F
from
.base
import
ALL
,
is_all
,
DGLError
,
dgl_warning
,
NID
,
EID
from
.base
import
ALL
,
is_all
,
DGLError
,
NID
,
EID
from
.heterograph_index
import
disjoint_union
,
slice_gidx
from
.heterograph
import
DGLGraph
from
.
import
convert
...
...
@@ -11,8 +11,7 @@ from . import utils
__all__
=
[
'batch'
,
'unbatch'
,
'slice_batch'
]
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
,
*
,
node_attrs
=
None
,
edge_attrs
=
None
):
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
):
r
"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
graph computation.
...
...
@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
"""
if
len
(
graphs
)
==
0
:
raise
DGLError
(
'The input list of graphs cannot be empty.'
)
if
node_attrs
is
not
None
:
dgl_warning
(
'Arguments node_attrs has been deprecated. Please use'
' ndata instead.'
)
ndata
=
node_attrs
if
edge_attrs
is
not
None
:
dgl_warning
(
'Arguments edge_attrs has been deprecated. Please use'
' edata instead.'
)
edata
=
edge_attrs
if
not
(
is_all
(
ndata
)
or
isinstance
(
ndata
,
list
)
or
ndata
is
None
):
raise
DGLError
(
'Invalid argument ndata: must be a string list but got {}.'
.
format
(
type
(
ndata
)))
...
...
python/dgl/function/message.py
View file @
61139302
...
...
@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir
from
.._deprecate.runtime.ir
import
var
__all__
=
[
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
,
"copy_u"
,
"copy_e"
,
__all__
=
[
"copy_u"
,
"copy_e"
,
"BinaryMessageFunction"
,
"CopyMessageFunction"
]
...
...
@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction):
See Also
--------
src
_mul_e
dge
u
_mul_e
"""
def
__init__
(
self
,
binary_op
,
lhs
,
rhs
,
lhs_field
,
rhs_field
,
out_field
):
self
.
binary_op
=
binary_op
...
...
@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction):
See Also
--------
copy_
src
copy_
u
"""
def
__init__
(
self
,
target
,
in_field
,
out_field
):
self
.
target
=
target
...
...
@@ -218,86 +218,3 @@ def _register_builtin_message_func():
__all__
.
append
(
func
.
__name__
)
_register_builtin_message_func
()
##############################################################################
# For backward compatibility
def
src_mul_edge
(
src
,
edge
,
out
):
"""Builtin message function that computes message by performing
binary operation mul between src feature and edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.u_mul_e` instead.
Parameters
----------
src : str
The source feature field.
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
"""
return
getattr
(
sys
.
modules
[
__name__
],
"u_mul_e"
)(
src
,
edge
,
out
)
def
copy_src
(
src
,
out
):
"""Builtin message function that computes message using source node
feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_u` instead.
Parameters
----------
src : str
The source feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_src('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.src['h']}
"""
return
copy_u
(
src
,
out
)
def
copy_edge
(
edge
,
out
):
"""Builtin message function that computes message using edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_e` instead.
Parameters
----------
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_edge('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.data['h']}
"""
return
copy_e
(
edge
,
out
)
python/dgl/heterograph.py
View file @
61139302
...
...
@@ -346,14 +346,6 @@ class DGLGraph(object):
self
.
_node_frames
[
ntid
].
append
(
data
)
self
.
_reset_cached_info
()
def
add_edge
(
self
,
u
,
v
,
data
=
None
,
etype
=
None
):
"""Add one edge to the graph.
DEPRECATED: please use ``add_edges``.
"""
dgl_warning
(
"DGLGraph.add_edge is deprecated. Please use DGLGraph.add_edges"
)
self
.
add_edges
(
u
,
v
,
data
,
etype
)
def
add_edges
(
self
,
u
,
v
,
data
=
None
,
etype
=
None
):
r
"""Add multiple new edges for the specified edge type
...
...
@@ -2623,20 +2615,6 @@ class DGLGraph(object):
"""
return
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
@
property
def
is_readonly
(
self
):
"""**DEPRECATED**: DGLGraph will always be mutable.
Returns
-------
bool
True if the graph is readonly, False otherwise.
"""
dgl_warning
(
'DGLGraph.is_readonly is deprecated in v0.5.
\n
'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.'
)
return
False
@
property
def
idtype
(
self
):
"""The data type for storing the structure-related graph information
...
...
@@ -2682,12 +2660,6 @@ class DGLGraph(object):
"""
return
self
.
_graph
.
dtype
def
__contains__
(
self
,
vid
):
"""**DEPRECATED**: please directly call :func:`has_nodes`."""
dgl_warning
(
'DGLGraph.__contains__ is deprecated.'
' Please directly call has_nodes.'
)
return
self
.
has_nodes
(
vid
)
def
has_nodes
(
self
,
vid
,
ntype
=
None
):
"""Return whether the graph contains the given nodes.
...
...
@@ -2745,14 +2717,6 @@ class DGLGraph(object):
else
:
return
F
.
astype
(
ret
,
F
.
bool
)
def
has_node
(
self
,
vid
,
ntype
=
None
):
"""Whether the graph has a particular node of a given type.
**DEPRECATED**: see :func:`~DGLGraph.has_nodes`
"""
dgl_warning
(
"DGLGraph.has_node is deprecated. Please use DGLGraph.has_nodes"
)
return
self
.
has_nodes
(
vid
,
ntype
)
def
has_edges_between
(
self
,
u
,
v
,
etype
=
None
):
"""Return whether the graph contains the given edges.
...
...
@@ -2843,15 +2807,6 @@ class DGLGraph(object):
else
:
return
F
.
astype
(
ret
,
F
.
bool
)
def
has_edge_between
(
self
,
u
,
v
,
etype
=
None
):
"""Whether the graph has edges of type ``etype``.
**DEPRECATED**: please use :func:`~DGLGraph.has_edge_between`.
"""
dgl_warning
(
"DGLGraph.has_edge_between is deprecated. "
"Please use DGLGraph.has_edges_between"
)
return
self
.
has_edges_between
(
u
,
v
,
etype
)
def
predecessors
(
self
,
v
,
etype
=
None
):
"""Return the predecessor(s) of a particular node with the specified edge type.
...
...
@@ -2969,17 +2924,7 @@ class DGLGraph(object):
raise
DGLError
(
'Non-existing node ID {}'
.
format
(
v
))
return
self
.
_graph
.
successors
(
self
.
get_etype_id
(
etype
),
v
)
def
edge_id
(
self
,
u
,
v
,
force_multi
=
None
,
return_uv
=
False
,
etype
=
None
):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`, with the specified edge type
**DEPRECATED**: See edge_ids
"""
dgl_warning
(
"DGLGraph.edge_id is deprecated. Please use DGLGraph.edge_ids."
)
return
self
.
edge_ids
(
u
,
v
,
force_multi
=
force_multi
,
return_uv
=
return_uv
,
etype
=
etype
)
def
edge_ids
(
self
,
u
,
v
,
force_multi
=
None
,
return_uv
=
False
,
etype
=
None
):
def
edge_ids
(
self
,
u
,
v
,
return_uv
=
False
,
etype
=
None
):
"""Return the edge ID(s) given the two endpoints of the edge(s).
Parameters
...
...
@@ -2999,9 +2944,6 @@ class DGLGraph(object):
* Int Tensor: Each element is a node ID. The tensor must have the same device type
and ID data type as the graph's.
* iterable[int]: Each element is a node ID.
force_multi : bool, optional
**DEPRECATED**, use :attr:`return_uv` instead. Whether to allow the graph to be a
multigraph, i.e. there can be multiple edges from one node to another.
return_uv : bool, optional
Whether to return the source and destination node IDs along with the edges. If
False (default), it assumes that the graph is a simple graph and there is only
...
...
@@ -3084,10 +3026,6 @@ class DGLGraph(object):
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
'v contains invalid node IDs'
)
if
force_multi
is
not
None
:
dgl_warning
(
"force_multi will be deprecated, "
\
"Please use return_uv instead"
)
return_uv
=
force_multi
if
return_uv
:
return
self
.
_graph
.
edge_ids_all
(
self
.
get_etype_id
(
etype
),
u
,
v
)
...
...
@@ -3424,14 +3362,6 @@ class DGLGraph(object):
else
:
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
def
in_degree
(
self
,
v
,
etype
=
None
):
"""Return the in-degree of node ``v`` with edges of type ``etype``.
**DEPRECATED**: Please use in_degrees
"""
dgl_warning
(
"DGLGraph.in_degree is deprecated. Please use DGLGraph.in_degrees"
)
return
self
.
in_degrees
(
v
,
etype
)
def
in_degrees
(
self
,
v
=
ALL
,
etype
=
None
):
"""Return the in-degree(s) of the given nodes.
...
...
@@ -3508,14 +3438,6 @@ class DGLGraph(object):
else
:
return
deg
def
out_degree
(
self
,
u
,
etype
=
None
):
"""Return the out-degree of node `u` with edges of type ``etype``.
DEPRECATED: please use DGL.out_degrees
"""
dgl_warning
(
"DGLGraph.out_degree is deprecated. Please use DGLGraph.out_degrees"
)
return
self
.
out_degrees
(
u
,
etype
)
def
out_degrees
(
self
,
u
=
ALL
,
etype
=
None
):
"""Return the out-degree(s) of the given nodes.
...
...
@@ -3713,15 +3635,6 @@ class DGLGraph(object):
else
:
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
False
,
fmt
)[
2
:]
def
adjacency_matrix_scipy
(
self
,
transpose
=
False
,
fmt
=
'csr'
,
return_edge_ids
=
None
):
"""DEPRECATED: please use ``dgl.adjacency_matrix(transpose, scipy_fmt=fmt)``.
"""
dgl_warning
(
'DGLGraph.adjacency_matrix_scipy is deprecated. '
'Please replace it with:
\n\n\t
'
'DGLGraph.adjacency_matrix(transpose, scipy_fmt="{}").
\n
'
.
format
(
fmt
))
return
self
.
adjacency_matrix
(
transpose
=
transpose
,
scipy_fmt
=
fmt
)
def
inc
(
self
,
typestr
,
ctx
=
F
.
cpu
(),
etype
=
None
):
"""Return the incidence matrix representation of edges with the given
edge type.
...
...
@@ -4283,7 +4196,7 @@ class DGLGraph(object):
# Message passing
#################################################################
def
apply_nodes
(
self
,
func
,
v
=
ALL
,
ntype
=
None
,
inplace
=
False
):
def
apply_nodes
(
self
,
func
,
v
=
ALL
,
ntype
=
None
):
"""Update the features of the specified nodes by the provided function.
Parameters
...
...
@@ -4303,8 +4216,6 @@ class DGLGraph(object):
ntype : str, optional
The node type name. Can be omitted if there is
only one type of nodes in the graph.
inplace : bool, optional
**DEPRECATED**.
Examples
--------
...
...
@@ -4340,8 +4251,6 @@ class DGLGraph(object):
--------
apply_edges
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
ntid
=
self
.
get_ntype_id
(
ntype
)
ntype
=
self
.
ntypes
[
ntid
]
if
is_all
(
v
):
...
...
@@ -4351,7 +4260,7 @@ class DGLGraph(object):
ndata
=
core
.
invoke_node_udf
(
self
,
v_id
,
ntype
,
func
,
orig_nid
=
v_id
)
self
.
_set_n_repr
(
ntid
,
v
,
ndata
)
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
):
"""Update the features of the specified edges by the provided function.
Parameters
...
...
@@ -4382,9 +4291,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
-----
DGL recommends using DGL's bulit-in function for the :attr:`func` argument,
...
...
@@ -4435,8 +4341,6 @@ class DGLGraph(object):
--------
apply_nodes
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
# Graph with one relation type
if
self
.
_graph
.
number_of_etypes
()
==
1
or
etype
is
not
None
:
etid
=
self
.
get_etype_id
(
etype
)
...
...
@@ -4476,8 +4380,7 @@ class DGLGraph(object):
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
,
inplace
=
False
):
etype
=
None
):
"""Send messages along the specified edges and reduce them on
the destination nodes to update their features.
...
...
@@ -4513,9 +4416,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
-----
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...
...
@@ -4558,7 +4458,7 @@ class DGLGraph(object):
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.send_and_recv(g['follows'].edges(), fn.copy_
src
('h', 'm'),
>>> g.send_and_recv(g['follows'].edges(), fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
...
...
@@ -4588,8 +4488,6 @@ class DGLGraph(object):
Note that the feature of node 0 remains the same as it has no incoming edges.
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
# edge type
etid
=
self
.
get_etype_id
(
etype
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
...
...
@@ -4612,8 +4510,7 @@ class DGLGraph(object):
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
,
inplace
=
False
):
etype
=
None
):
"""Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features.
...
...
@@ -4645,9 +4542,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
-----
* If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke
...
...
@@ -4688,14 +4582,12 @@ class DGLGraph(object):
Pull.
>>> g['follows'].pull(2, fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].pull(2, fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
[1.],
[1.]])
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
if
len
(
v
)
==
0
:
# no computation
...
...
@@ -4716,8 +4608,7 @@ class DGLGraph(object):
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
,
inplace
=
False
):
etype
=
None
):
"""Send message from the specified node(s) to their successors
along the specified edge type and update their node features.
...
...
@@ -4749,9 +4640,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
-----
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...
...
@@ -4785,14 +4673,12 @@ class DGLGraph(object):
Push.
>>> g['follows'].push(0, fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].push(0, fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
[0.],
[0.]])
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
edges
=
self
.
out_edges
(
u
,
form
=
'eid'
,
etype
=
etype
)
self
.
send_and_recv
(
edges
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
)
...
...
@@ -4864,7 +4750,7 @@ class DGLGraph(object):
Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].update_all(fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
[0.],
...
...
@@ -4881,7 +4767,7 @@ class DGLGraph(object):
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))
>>> g.update_all(fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
tensor([[0.],
[4.]])
...
...
@@ -4989,8 +4875,8 @@ class DGLGraph(object):
Update all.
>>> g.multi_update_all(
... {'follows': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))},
... {'follows': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))},
... "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
...
...
@@ -5004,8 +4890,8 @@ class DGLGraph(object):
Use the user-defined cross reducer.
>>> g.multi_update_all(
... {'follows': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))},
... {'follows': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))},
... cross_sum)
"""
all_out
=
defaultdict
(
list
)
...
...
@@ -5088,7 +4974,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_
src
('h', 'm'),
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
tensor([[1.],
[2.],
...
...
@@ -5151,7 +5037,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_
src
('h', 'm'),
>>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[1.],
...
...
@@ -6062,113 +5948,6 @@ class DGLGraph(object):
"""
return
self
.
astype
(
F
.
int32
)
#################################################################
# DEPRECATED: from the old DGLGraph
#################################################################
def
from_networkx
(
self
,
nx_graph
,
node_attrs
=
None
,
edge_attrs
=
None
):
"""DEPRECATED: please use
``dgl.from_networkx(nx_graph, node_attrs, edge_attrs)``
which will return a new graph created from the networkx graph.
"""
raise
DGLError
(
'DGLGraph.from_networkx is deprecated. Please call the following
\n\n
'
'
\t
dgl.from_networkx(nx_graph, node_attrs, edge_attrs)
\n\n
'
', which creates a new DGLGraph from the networkx graph.'
)
def
from_scipy_sparse_matrix
(
self
,
spmat
,
multigraph
=
None
):
"""DEPRECATED: please use
``dgl.from_scipy(spmat)``
which will return a new graph created from the scipy matrix.
"""
raise
DGLError
(
'DGLGraph.from_scipy_sparse_matrix is deprecated. '
'Please call the following
\n\n
'
'
\t
dgl.from_scipy(spmat)
\n\n
'
', which creates a new DGLGraph from the scipy matrix.'
)
def
register_apply_node_func
(
self
,
func
):
"""Deprecated: please directly call :func:`apply_nodes` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_apply_node_func is deprecated.'
' Please directly call apply_nodes with func as the argument.'
)
def
register_apply_edge_func
(
self
,
func
):
"""Deprecated: please directly call :func:`apply_edges` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_apply_edge_func is deprecated.'
' Please directly call apply_edges with func as the argument.'
)
def
register_message_func
(
self
,
func
):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_message_func is deprecated.'
' Please directly call update_all with func as the argument.'
)
def
register_reduce_func
(
self
,
func
):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_reduce_func is deprecated.'
' Please directly call update_all with func as the argument.'
)
def
group_apply_edges
(
self
,
group_by
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
"""**DEPRECATED**: The API is removed in 0.5."""
raise
DGLError
(
'DGLGraph.group_apply_edges is removed in 0.5.'
)
def
send
(
self
,
edges
,
message_func
,
etype
=
None
):
"""Send messages along the given edges with the same edge type.
DEPRECATE: please use send_and_recv, update_all.
"""
raise
DGLError
(
'DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges
\n
'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv
\n
'
' and set the message function as dgl.function.copy_e to conduct message
\n
'
' aggregation.'
)
def
recv
(
self
,
v
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
,
inplace
=
False
):
r
"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
DEPRECATE: please use send_and_recv, update_all.
"""
raise
DGLError
(
'DGLGraph.recv is deprecated. As a replacement, use DGLGraph.apply_edges
\n
'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv
\n
'
' and set the message function as dgl.function.copy_e to conduct message
\n
'
' aggregation.'
)
def
multi_recv
(
self
,
v
,
reducer_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""Receive messages from multiple edge types and perform aggregation.
DEPRECATE: please use multi_send_and_recv, multi_update_all.
"""
raise
DGLError
(
'DGLGraph.multi_recv is deprecated. As a replacement,
\n
'
' use DGLGraph.apply_edges API to compute messages as edge data.
\n
'
' Then use DGLGraph.multi_send_and_recv and set the message function
\n
'
' as dgl.function.copy_e to conduct message aggregation.'
)
def
multi_send_and_recv
(
self
,
etype_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""**DEPRECATED**: The API is removed in v0.5."""
raise
DGLError
(
'DGLGraph.multi_pull is removed in v0.5. As a replacement,
\n
'
' use DGLGraph.edge_subgraph to extract the subgraph first
\n
'
' and then call DGLGraph.multi_update_all.'
)
def
multi_pull
(
self
,
v
,
etype_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""**DEPRECATED**: The API is removed in v0.5."""
raise
DGLError
(
'DGLGraph.multi_pull is removed in v0.5. As a replacement,
\n
'
' use DGLGraph.edge_subgraph to extract the subgraph first
\n
'
' and then call DGLGraph.multi_update_all.'
)
def
readonly
(
self
,
readonly_state
=
True
):
"""Deprecated: DGLGraph will always be mutable."""
dgl_warning
(
'DGLGraph.readonly is deprecated in v0.5.
\n
'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.'
)
############################################################
# Internal APIs
############################################################
...
...
python/dgl/nn/mxnet/conv/graphconv.py
View file @
61139302
...
...
@@ -261,13 +261,13 @@ class GraphConv(gluon.Block):
if
weight
is
not
None
:
feat_src
=
mx
.
nd
.
dot
(
feat_src
,
weight
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
.
pop
(
'h'
)
else
:
# aggregate first then mult W
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
.
pop
(
'h'
)
if
weight
is
not
None
:
...
...
python/dgl/nn/mxnet/conv/tagconv.py
View file @
61139302
...
...
@@ -114,7 +114,7 @@ class TAGConv(gluon.Block):
rst
=
rst
*
norm
graph
.
ndata
[
'h'
]
=
rst
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
ndata
[
'h'
]
rst
=
rst
*
norm
...
...
python/dgl/nn/pytorch/conv/ginconv.py
View file @
61139302
...
...
@@ -136,7 +136,7 @@ class GINConv(nn.Module):
"""
_reducer
=
getattr
(
fn
,
self
.
_aggregator_type
)
with
graph
.
local_scope
():
aggregate_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
aggregate_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/conv/graphconv.py
View file @
61139302
...
...
@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module):
if
self
.
_norm
==
'both'
:
reversed_g
=
reverse
(
graph
)
reversed_g
.
edata
[
'_edge_w'
]
=
edge_weight
reversed_g
.
update_all
(
fn
.
copy_e
dge
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'out_weight'
))
reversed_g
.
update_all
(
fn
.
copy_e
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'out_weight'
))
degs
=
reversed_g
.
dstdata
[
'out_weight'
]
+
self
.
_eps
norm
=
th
.
pow
(
degs
,
-
0.5
)
graph
.
srcdata
[
'_src_out_w'
]
=
norm
if
self
.
_norm
!=
'none'
:
graph
.
update_all
(
fn
.
copy_e
dge
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'in_weight'
))
graph
.
update_all
(
fn
.
copy_e
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'in_weight'
))
degs
=
graph
.
dstdata
[
'in_weight'
]
+
self
.
_eps
if
self
.
_norm
==
'both'
:
norm
=
th
.
pow
(
degs
,
-
0.5
)
...
...
@@ -389,7 +389,7 @@ class GraphConv(nn.Module):
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.'
)
aggregate_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
aggregate_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/conv/sageconv.py
View file @
61139302
...
...
@@ -213,7 +213,7 @@ class SAGEConv(nn.Module):
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
if
graph
.
is_block
:
feat_dst
=
feat_src
[:
graph
.
number_of_dst_nodes
()]
msg_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
msg_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/sparse_emb.py
View file @
61139302
...
...
@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding
"""
self
.
_trace
=
[]
@
property
def
emb_tensor
(
self
):
"""Return the tensor storing the node embeddings
DEPRECATED: renamed weight
Returns
-------
torch.Tensor
The tensor storing the node embeddings
"""
return
self
.
_tensor
@
property
def
weight
(
self
):
"""Return the tensor storing the node embeddings
...
...
python/dgl/nn/tensorflow/conv/graphconv.py
View file @
61139302
...
...
@@ -253,13 +253,13 @@ class GraphConv(layers.Layer):
if
weight
is
not
None
:
feat_src
=
tf
.
matmul
(
feat_src
,
weight
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
[
'h'
]
else
:
# aggregate first then mult W
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
[
'h'
]
if
weight
is
not
None
:
...
...
python/dgl/nn/tensorflow/conv/sageconv.py
View file @
61139302
...
...
@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer):
if
self
.
_aggre_type
==
'mean'
:
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
dstdata
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
check_eq_shape
(
feat
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
]
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
graph
.
srcdata
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat_src
))
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
dstdata
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
h_neigh
=
graph
.
dstdata
[
'neigh'
]
else
:
raise
KeyError
(
...
...
python/dgl/optim/pytorch/sparse_optim.py
View file @
61139302
...
...
@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer):
),
"SparseAdagrad only supports dgl.nn.NodeEmbedding"
emb_name
=
emb
.
name
if
th
.
device
(
emb
.
emb_tensor
.
device
)
==
th
.
device
(
"cpu"
):
if
th
.
device
(
emb
.
weight
.
device
)
==
th
.
device
(
"cpu"
):
# if our embedding is on the CPU, our state also has to be
if
self
.
_rank
<
0
:
state
=
th
.
empty
(
...
...
@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer):
else
:
# distributed state on on gpu
state
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
th
.
float32
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
emb
.
set_optm_state
(
state
)
...
...
@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer):
),
"SparseAdam only supports dgl.nn.NodeEmbedding"
emb_name
=
emb
.
name
self
.
_is_using_uva
[
emb_name
]
=
self
.
_use_uva
if
th
.
device
(
emb
.
emb_tensor
.
device
)
==
th
.
device
(
"cpu"
):
if
th
.
device
(
emb
.
weight
.
device
)
==
th
.
device
(
"cpu"
):
# if our embedding is on the CPU, our state also has to be
if
self
.
_rank
<
0
:
state_step
=
th
.
empty
(
...
...
@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer):
# distributed state on on gpu
state_step
=
th
.
empty
(
[
emb
.
emb_tensor
.
shape
[
0
]],
[
emb
.
weight
.
shape
[
0
]],
dtype
=
th
.
int32
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
state_mem
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
self
.
_dtype
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
state_power
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
self
.
_dtype
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
state
=
(
state_step
,
state_mem
,
state_power
)
emb
.
set_optm_state
(
state
)
...
...
tests/compute/test_basics.py
View file @
61139302
...
...
@@ -32,10 +32,10 @@ def generate_graph_old(grad=False):
# create a graph where 0 is the source and 9 is the sink
# 17 edges
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
g
.
add_edge
s
(
9
,
0
)
g
=
g
.
to
(
F
.
ctx
())
ncol
=
F
.
randn
((
10
,
D
))
ecol
=
F
.
randn
((
17
,
D
))
...
...
@@ -431,8 +431,8 @@ def test_dynamic_addition():
assert
g
.
ndata
[
'h1'
].
shape
[
0
]
==
g
.
ndata
[
'h2'
].
shape
[
0
]
==
N
+
3
# Test edge addition
g
.
add_edge
(
0
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
s
(
0
,
1
)
g
.
add_edge
s
(
1
,
0
)
g
.
edata
.
update
({
'h1'
:
F
.
randn
((
2
,
D
)),
'h2'
:
F
.
randn
((
2
,
D
))})
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
2
...
...
@@ -441,12 +441,12 @@ def test_dynamic_addition():
g
.
edata
[
'h1'
]
=
F
.
randn
((
4
,
D
))
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
4
g
.
add_edge
(
1
,
2
)
g
.
add_edge
s
(
1
,
2
)
g
.
edges
[
4
].
data
[
'h1'
]
=
F
.
randn
((
1
,
D
))
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
5
# test add edge with part of the features
g
.
add_edge
(
2
,
1
,
{
'h1'
:
F
.
randn
((
1
,
D
))})
g
.
add_edge
s
(
2
,
1
,
{
'h1'
:
F
.
randn
((
1
,
D
))})
assert
len
(
g
.
edata
[
'h1'
])
==
len
(
g
.
edata
[
'h2'
])
...
...
tests/compute/test_batched_graph.py
View file @
61139302
...
...
@@ -15,10 +15,10 @@ def tree1(idtype):
"""
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
s
(
3
,
1
)
g
.
add_edge
s
(
4
,
1
)
g
.
add_edge
s
(
1
,
0
)
g
.
add_edge
s
(
2
,
0
)
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
return
g
...
...
@@ -34,10 +34,10 @@ def tree2(idtype):
"""
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
g
.
add_edge
(
2
,
4
)
g
.
add_edge
(
0
,
4
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
s
(
2
,
4
)
g
.
add_edge
s
(
0
,
4
)
g
.
add_edge
s
(
4
,
1
)
g
.
add_edge
s
(
3
,
1
)
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
return
g
...
...
@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype):
e2
=
F
.
randn
((
6
,
10
))
g2
.
edata
[
'h'
]
=
e2
g
=
dgl
.
batch
([
g1
,
g2
])
r1
=
g
.
edata
[
'h'
][
g
.
edge_id
(
4
,
5
)]
r2
=
g1
.
edata
[
'h'
][
g1
.
edge_id
(
4
,
5
)]
r1
=
g
.
edata
[
'h'
][
g
.
edge_id
s
(
4
,
5
)]
r2
=
g1
.
edata
[
'h'
][
g1
.
edge_id
s
(
4
,
5
)]
assert
F
.
array_equal
(
r1
,
r2
)
@
parametrize_idtype
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment