Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
61139302
Unverified
Commit
61139302
authored
Dec 01, 2022
by
peizhou001
Committed by
GitHub
Dec 01, 2022
Browse files
[API Deprecation] Remove candidates in DGLGraph (#4946)
parent
e088acac
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
75 additions
and
401 deletions
+75
-401
examples/pytorch/transformer/modules/act.py
examples/pytorch/transformer/modules/act.py
+1
-1
examples/pytorch/transformer/modules/models.py
examples/pytorch/transformer/modules/models.py
+2
-2
examples/pytorch/transformer/modules/viz.py
examples/pytorch/transformer/modules/viz.py
+1
-1
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+3
-3
examples/tensorflow/gcn/gcn_builtin.py
examples/tensorflow/gcn/gcn_builtin.py
+1
-1
python/dgl/_deprecate/graph.py
python/dgl/_deprecate/graph.py
+4
-4
python/dgl/batch.py
python/dgl/batch.py
+2
-11
python/dgl/function/message.py
python/dgl/function/message.py
+3
-86
python/dgl/heterograph.py
python/dgl/heterograph.py
+17
-238
python/dgl/nn/mxnet/conv/graphconv.py
python/dgl/nn/mxnet/conv/graphconv.py
+2
-2
python/dgl/nn/mxnet/conv/tagconv.py
python/dgl/nn/mxnet/conv/tagconv.py
+1
-1
python/dgl/nn/pytorch/conv/ginconv.py
python/dgl/nn/pytorch/conv/ginconv.py
+1
-1
python/dgl/nn/pytorch/conv/graphconv.py
python/dgl/nn/pytorch/conv/graphconv.py
+3
-3
python/dgl/nn/pytorch/conv/sageconv.py
python/dgl/nn/pytorch/conv/sageconv.py
+1
-1
python/dgl/nn/pytorch/sparse_emb.py
python/dgl/nn/pytorch/sparse_emb.py
+0
-13
python/dgl/nn/tensorflow/conv/graphconv.py
python/dgl/nn/tensorflow/conv/graphconv.py
+2
-2
python/dgl/nn/tensorflow/conv/sageconv.py
python/dgl/nn/tensorflow/conv/sageconv.py
+4
-4
python/dgl/optim/pytorch/sparse_optim.py
python/dgl/optim/pytorch/sparse_optim.py
+10
-10
tests/compute/test_basics.py
tests/compute/test_basics.py
+7
-7
tests/compute/test_batched_graph.py
tests/compute/test_batched_graph.py
+10
-10
No files found.
examples/pytorch/transformer/modules/act.py
View file @
61139302
...
@@ -117,7 +117,7 @@ class UTransformer(nn.Module):
...
@@ -117,7 +117,7 @@ class UTransformer(nn.Module):
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
g
.
send_and_recv
(
eids
,
[
fn
.
src
_mul_e
dge
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
dge
(
'score'
,
'score'
)],
[
fn
.
u
_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
(
'score'
,
'score'
)],
[
fn
.
sum
(
'v'
,
'wv'
),
fn
.
sum
(
'score'
,
'z'
)])
[
fn
.
sum
(
'v'
,
'wv'
),
fn
.
sum
(
'score'
,
'z'
)])
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
...
...
examples/pytorch/transformer/modules/models.py
View file @
61139302
...
@@ -79,8 +79,8 @@ class Transformer(nn.Module):
...
@@ -79,8 +79,8 @@ class Transformer(nn.Module):
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'score'
),
eids
)
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'score'
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
scaled_exp
(
'score'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
fn
.
src
_mul_e
dge
(
'v'
,
'score'
,
'v'
),
fn
.
sum
(
'v'
,
'wv'
))
g
.
send_and_recv
(
eids
,
fn
.
u
_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
sum
(
'v'
,
'wv'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
dge
(
'score'
,
'score'
),
fn
.
sum
(
'score'
,
'z'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
'score'
,
'score'
),
fn
.
sum
(
'score'
,
'z'
))
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
"Update the node states and edge states of the graph."
...
...
examples/pytorch/transformer/modules/viz.py
View file @
61139302
...
@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
...
@@ -17,7 +17,7 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
for
j
,
dst
in
enumerate
(
dst_nodes
.
tolist
()):
for
j
,
dst
in
enumerate
(
dst_nodes
.
tolist
()):
if
not
g
.
has_edge_between
(
src
,
dst
):
if
not
g
.
has_edge_between
(
src
,
dst
):
continue
continue
eid
=
g
.
edge_id
(
src
,
dst
)
eid
=
g
.
edge_id
s
(
src
,
dst
)
weight
[
i
][
j
]
=
g
.
edata
[
'score'
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
[
i
][
j
]
=
g
.
edata
[
'score'
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
=
weight
.
transpose
(
0
,
2
)
weight
=
weight
.
transpose
(
0
,
2
)
...
...
examples/pytorch/tree_lstm/train.py
View file @
61139302
...
@@ -131,7 +131,7 @@ def main(args):
...
@@ -131,7 +131,7 @@ def main(args):
root_ids
=
[
root_ids
=
[
i
i
for
i
in
range
(
g
.
number_of_nodes
())
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
if
g
.
out_degree
s
(
i
)
==
0
]
]
root_acc
=
np
.
sum
(
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
@@ -170,7 +170,7 @@ def main(args):
...
@@ -170,7 +170,7 @@ def main(args):
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
accs
.
append
([
acc
,
len
(
batch
.
label
)])
accs
.
append
([
acc
,
len
(
batch
.
label
)])
root_ids
=
[
root_ids
=
[
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
s
(
i
)
==
0
]
]
root_acc
=
np
.
sum
(
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
@@ -222,7 +222,7 @@ def main(args):
...
@@ -222,7 +222,7 @@ def main(args):
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
acc
=
th
.
sum
(
th
.
eq
(
batch
.
label
,
pred
)).
item
()
accs
.
append
([
acc
,
len
(
batch
.
label
)])
accs
.
append
([
acc
,
len
(
batch
.
label
)])
root_ids
=
[
root_ids
=
[
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
(
i
)
==
0
i
for
i
in
range
(
g
.
number_of_nodes
())
if
g
.
out_degree
s
(
i
)
==
0
]
]
root_acc
=
np
.
sum
(
root_acc
=
np
.
sum
(
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
batch
.
label
.
cpu
().
data
.
numpy
()[
root_ids
]
...
...
examples/tensorflow/gcn/gcn_builtin.py
View file @
61139302
...
@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer):
...
@@ -45,7 +45,7 @@ class GCNLayer(layers.Layer):
h
=
self
.
dropout
(
h
)
h
=
self
.
dropout
(
h
)
self
.
g
.
ndata
[
'h'
]
=
tf
.
matmul
(
h
,
self
.
weight
)
self
.
g
.
ndata
[
'h'
]
=
tf
.
matmul
(
h
,
self
.
weight
)
self
.
g
.
ndata
[
'norm_h'
]
=
self
.
g
.
ndata
[
'h'
]
*
self
.
g
.
ndata
[
'norm'
]
self
.
g
.
ndata
[
'norm_h'
]
=
self
.
g
.
ndata
[
'h'
]
*
self
.
g
.
ndata
[
'norm'
]
self
.
g
.
update_all
(
fn
.
copy_
src
(
'norm_h'
,
'm'
),
self
.
g
.
update_all
(
fn
.
copy_
u
(
'norm_h'
,
'm'
),
fn
.
sum
(
'm'
,
'h'
))
fn
.
sum
(
'm'
,
'h'
))
h
=
self
.
g
.
ndata
[
'h'
]
h
=
self
.
g
.
ndata
[
'h'
]
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
...
...
python/dgl/_deprecate/graph.py
View file @
61139302
...
@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3083,10 +3083,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3)
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])
Use the built-in message function :func:`~dgl.function.copy_
src
` for copying
Use the built-in message function :func:`~dgl.function.copy_
u
` for copying
node features as the message.
node features as the message.
>>> m_func = dgl.function.copy_
src
('x', 'm')
>>> m_func = dgl.function.copy_
u
('x', 'm')
>>> g.register_message_func(m_func)
>>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which
Use the built-int message reducing function :func:`~dgl.function.sum`, which
...
@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3180,10 +3180,10 @@ class DGLGraph(DGLBaseGraph):
>>> g.add_nodes(3)
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
Use the built-in message function :func:`~dgl.function.copy_
src
` for copying
Use the built-in message function :func:`~dgl.function.copy_
u
` for copying
node features as the message.
node features as the message.
>>> m_func = dgl.function.copy_
src
('x', 'm')
>>> m_func = dgl.function.copy_
u
('x', 'm')
>>> g.register_message_func(m_func)
>>> g.register_message_func(m_func)
Use the built-int message reducing function :func:`~dgl.function.sum`, which
Use the built-int message reducing function :func:`~dgl.function.sum`, which
...
...
python/dgl/batch.py
View file @
61139302
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.base
import
ALL
,
is_all
,
DGLError
,
dgl_warning
,
NID
,
EID
from
.base
import
ALL
,
is_all
,
DGLError
,
NID
,
EID
from
.heterograph_index
import
disjoint_union
,
slice_gidx
from
.heterograph_index
import
disjoint_union
,
slice_gidx
from
.heterograph
import
DGLGraph
from
.heterograph
import
DGLGraph
from
.
import
convert
from
.
import
convert
...
@@ -11,8 +11,7 @@ from . import utils
...
@@ -11,8 +11,7 @@ from . import utils
__all__
=
[
'batch'
,
'unbatch'
,
'slice_batch'
]
__all__
=
[
'batch'
,
'unbatch'
,
'slice_batch'
]
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
,
*
,
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
):
node_attrs
=
None
,
edge_attrs
=
None
):
r
"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
r
"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
graph computation.
graph computation.
...
@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
...
@@ -151,14 +150,6 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
"""
"""
if
len
(
graphs
)
==
0
:
if
len
(
graphs
)
==
0
:
raise
DGLError
(
'The input list of graphs cannot be empty.'
)
raise
DGLError
(
'The input list of graphs cannot be empty.'
)
if
node_attrs
is
not
None
:
dgl_warning
(
'Arguments node_attrs has been deprecated. Please use'
' ndata instead.'
)
ndata
=
node_attrs
if
edge_attrs
is
not
None
:
dgl_warning
(
'Arguments edge_attrs has been deprecated. Please use'
' edata instead.'
)
edata
=
edge_attrs
if
not
(
is_all
(
ndata
)
or
isinstance
(
ndata
,
list
)
or
ndata
is
None
):
if
not
(
is_all
(
ndata
)
or
isinstance
(
ndata
,
list
)
or
ndata
is
None
):
raise
DGLError
(
'Invalid argument ndata: must be a string list but got {}.'
.
format
(
raise
DGLError
(
'Invalid argument ndata: must be a string list but got {}.'
.
format
(
type
(
ndata
)))
type
(
ndata
)))
...
...
python/dgl/function/message.py
View file @
61139302
...
@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir
...
@@ -9,7 +9,7 @@ from .._deprecate.runtime import ir
from
.._deprecate.runtime.ir
import
var
from
.._deprecate.runtime.ir
import
var
__all__
=
[
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
,
"copy_u"
,
"copy_e"
,
__all__
=
[
"copy_u"
,
"copy_e"
,
"BinaryMessageFunction"
,
"CopyMessageFunction"
]
"BinaryMessageFunction"
,
"CopyMessageFunction"
]
...
@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction):
...
@@ -34,7 +34,7 @@ class BinaryMessageFunction(MessageFunction):
See Also
See Also
--------
--------
src
_mul_e
dge
u
_mul_e
"""
"""
def
__init__
(
self
,
binary_op
,
lhs
,
rhs
,
lhs_field
,
rhs_field
,
out_field
):
def
__init__
(
self
,
binary_op
,
lhs
,
rhs
,
lhs_field
,
rhs_field
,
out_field
):
self
.
binary_op
=
binary_op
self
.
binary_op
=
binary_op
...
@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction):
...
@@ -73,7 +73,7 @@ class CopyMessageFunction(MessageFunction):
See Also
See Also
--------
--------
copy_
src
copy_
u
"""
"""
def
__init__
(
self
,
target
,
in_field
,
out_field
):
def
__init__
(
self
,
target
,
in_field
,
out_field
):
self
.
target
=
target
self
.
target
=
target
...
@@ -218,86 +218,3 @@ def _register_builtin_message_func():
...
@@ -218,86 +218,3 @@ def _register_builtin_message_func():
__all__
.
append
(
func
.
__name__
)
__all__
.
append
(
func
.
__name__
)
_register_builtin_message_func
()
_register_builtin_message_func
()
##############################################################################
# For backward compatibility
def
src_mul_edge
(
src
,
edge
,
out
):
"""Builtin message function that computes message by performing
binary operation mul between src feature and edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.u_mul_e` instead.
Parameters
----------
src : str
The source feature field.
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
"""
return
getattr
(
sys
.
modules
[
__name__
],
"u_mul_e"
)(
src
,
edge
,
out
)
def
copy_src
(
src
,
out
):
"""Builtin message function that computes message using source node
feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_u` instead.
Parameters
----------
src : str
The source feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_src('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.src['h']}
"""
return
copy_u
(
src
,
out
)
def
copy_edge
(
edge
,
out
):
"""Builtin message function that computes message using edge feature.
Notes
-----
This function is deprecated. Please use :func:`~dgl.function.copy_e` instead.
Parameters
----------
edge : str
The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.copy_edge('h', 'm')
The above example is equivalent to the following user defined function:
>>> def message_func(edges):
>>> return {'m': edges.data['h']}
"""
return
copy_e
(
edge
,
out
)
python/dgl/heterograph.py
View file @
61139302
...
@@ -346,14 +346,6 @@ class DGLGraph(object):
...
@@ -346,14 +346,6 @@ class DGLGraph(object):
self
.
_node_frames
[
ntid
].
append
(
data
)
self
.
_node_frames
[
ntid
].
append
(
data
)
self
.
_reset_cached_info
()
self
.
_reset_cached_info
()
def
add_edge
(
self
,
u
,
v
,
data
=
None
,
etype
=
None
):
"""Add one edge to the graph.
DEPRECATED: please use ``add_edges``.
"""
dgl_warning
(
"DGLGraph.add_edge is deprecated. Please use DGLGraph.add_edges"
)
self
.
add_edges
(
u
,
v
,
data
,
etype
)
def
add_edges
(
self
,
u
,
v
,
data
=
None
,
etype
=
None
):
def
add_edges
(
self
,
u
,
v
,
data
=
None
,
etype
=
None
):
r
"""Add multiple new edges for the specified edge type
r
"""Add multiple new edges for the specified edge type
...
@@ -2623,20 +2615,6 @@ class DGLGraph(object):
...
@@ -2623,20 +2615,6 @@ class DGLGraph(object):
"""
"""
return
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
return
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
@
property
def
is_readonly
(
self
):
"""**DEPRECATED**: DGLGraph will always be mutable.
Returns
-------
bool
True if the graph is readonly, False otherwise.
"""
dgl_warning
(
'DGLGraph.is_readonly is deprecated in v0.5.
\n
'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.'
)
return
False
@
property
@
property
def
idtype
(
self
):
def
idtype
(
self
):
"""The data type for storing the structure-related graph information
"""The data type for storing the structure-related graph information
...
@@ -2682,12 +2660,6 @@ class DGLGraph(object):
...
@@ -2682,12 +2660,6 @@ class DGLGraph(object):
"""
"""
return
self
.
_graph
.
dtype
return
self
.
_graph
.
dtype
def
__contains__
(
self
,
vid
):
"""**DEPRECATED**: please directly call :func:`has_nodes`."""
dgl_warning
(
'DGLGraph.__contains__ is deprecated.'
' Please directly call has_nodes.'
)
return
self
.
has_nodes
(
vid
)
def
has_nodes
(
self
,
vid
,
ntype
=
None
):
def
has_nodes
(
self
,
vid
,
ntype
=
None
):
"""Return whether the graph contains the given nodes.
"""Return whether the graph contains the given nodes.
...
@@ -2745,14 +2717,6 @@ class DGLGraph(object):
...
@@ -2745,14 +2717,6 @@ class DGLGraph(object):
else
:
else
:
return
F
.
astype
(
ret
,
F
.
bool
)
return
F
.
astype
(
ret
,
F
.
bool
)
def
has_node
(
self
,
vid
,
ntype
=
None
):
"""Whether the graph has a particular node of a given type.
**DEPRECATED**: see :func:`~DGLGraph.has_nodes`
"""
dgl_warning
(
"DGLGraph.has_node is deprecated. Please use DGLGraph.has_nodes"
)
return
self
.
has_nodes
(
vid
,
ntype
)
def
has_edges_between
(
self
,
u
,
v
,
etype
=
None
):
def
has_edges_between
(
self
,
u
,
v
,
etype
=
None
):
"""Return whether the graph contains the given edges.
"""Return whether the graph contains the given edges.
...
@@ -2843,15 +2807,6 @@ class DGLGraph(object):
...
@@ -2843,15 +2807,6 @@ class DGLGraph(object):
else
:
else
:
return
F
.
astype
(
ret
,
F
.
bool
)
return
F
.
astype
(
ret
,
F
.
bool
)
def
has_edge_between
(
self
,
u
,
v
,
etype
=
None
):
"""Whether the graph has edges of type ``etype``.
**DEPRECATED**: please use :func:`~DGLGraph.has_edge_between`.
"""
dgl_warning
(
"DGLGraph.has_edge_between is deprecated. "
"Please use DGLGraph.has_edges_between"
)
return
self
.
has_edges_between
(
u
,
v
,
etype
)
def
predecessors
(
self
,
v
,
etype
=
None
):
def
predecessors
(
self
,
v
,
etype
=
None
):
"""Return the predecessor(s) of a particular node with the specified edge type.
"""Return the predecessor(s) of a particular node with the specified edge type.
...
@@ -2969,17 +2924,7 @@ class DGLGraph(object):
...
@@ -2969,17 +2924,7 @@ class DGLGraph(object):
raise
DGLError
(
'Non-existing node ID {}'
.
format
(
v
))
raise
DGLError
(
'Non-existing node ID {}'
.
format
(
v
))
return
self
.
_graph
.
successors
(
self
.
get_etype_id
(
etype
),
v
)
return
self
.
_graph
.
successors
(
self
.
get_etype_id
(
etype
),
v
)
def
edge_id
(
self
,
u
,
v
,
force_multi
=
None
,
return_uv
=
False
,
etype
=
None
):
def
edge_ids
(
self
,
u
,
v
,
return_uv
=
False
,
etype
=
None
):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`, with the specified edge type
**DEPRECATED**: See edge_ids
"""
dgl_warning
(
"DGLGraph.edge_id is deprecated. Please use DGLGraph.edge_ids."
)
return
self
.
edge_ids
(
u
,
v
,
force_multi
=
force_multi
,
return_uv
=
return_uv
,
etype
=
etype
)
def
edge_ids
(
self
,
u
,
v
,
force_multi
=
None
,
return_uv
=
False
,
etype
=
None
):
"""Return the edge ID(s) given the two endpoints of the edge(s).
"""Return the edge ID(s) given the two endpoints of the edge(s).
Parameters
Parameters
...
@@ -2999,9 +2944,6 @@ class DGLGraph(object):
...
@@ -2999,9 +2944,6 @@ class DGLGraph(object):
* Int Tensor: Each element is a node ID. The tensor must have the same device type
* Int Tensor: Each element is a node ID. The tensor must have the same device type
and ID data type as the graph's.
and ID data type as the graph's.
* iterable[int]: Each element is a node ID.
* iterable[int]: Each element is a node ID.
force_multi : bool, optional
**DEPRECATED**, use :attr:`return_uv` instead. Whether to allow the graph to be a
multigraph, i.e. there can be multiple edges from one node to another.
return_uv : bool, optional
return_uv : bool, optional
Whether to return the source and destination node IDs along with the edges. If
Whether to return the source and destination node IDs along with the edges. If
False (default), it assumes that the graph is a simple graph and there is only
False (default), it assumes that the graph is a simple graph and there is only
...
@@ -3084,10 +3026,6 @@ class DGLGraph(object):
...
@@ -3084,10 +3026,6 @@ class DGLGraph(object):
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
'v contains invalid node IDs'
)
raise
DGLError
(
'v contains invalid node IDs'
)
if
force_multi
is
not
None
:
dgl_warning
(
"force_multi will be deprecated, "
\
"Please use return_uv instead"
)
return_uv
=
force_multi
if
return_uv
:
if
return_uv
:
return
self
.
_graph
.
edge_ids_all
(
self
.
get_etype_id
(
etype
),
u
,
v
)
return
self
.
_graph
.
edge_ids_all
(
self
.
get_etype_id
(
etype
),
u
,
v
)
...
@@ -3424,14 +3362,6 @@ class DGLGraph(object):
...
@@ -3424,14 +3362,6 @@ class DGLGraph(object):
else
:
else
:
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
def
in_degree
(
self
,
v
,
etype
=
None
):
"""Return the in-degree of node ``v`` with edges of type ``etype``.
**DEPRECATED**: Please use in_degrees
"""
dgl_warning
(
"DGLGraph.in_degree is deprecated. Please use DGLGraph.in_degrees"
)
return
self
.
in_degrees
(
v
,
etype
)
def
in_degrees
(
self
,
v
=
ALL
,
etype
=
None
):
def
in_degrees
(
self
,
v
=
ALL
,
etype
=
None
):
"""Return the in-degree(s) of the given nodes.
"""Return the in-degree(s) of the given nodes.
...
@@ -3508,14 +3438,6 @@ class DGLGraph(object):
...
@@ -3508,14 +3438,6 @@ class DGLGraph(object):
else
:
else
:
return
deg
return
deg
def
out_degree
(
self
,
u
,
etype
=
None
):
"""Return the out-degree of node `u` with edges of type ``etype``.
DEPRECATED: please use DGL.out_degrees
"""
dgl_warning
(
"DGLGraph.out_degree is deprecated. Please use DGLGraph.out_degrees"
)
return
self
.
out_degrees
(
u
,
etype
)
def
out_degrees
(
self
,
u
=
ALL
,
etype
=
None
):
def
out_degrees
(
self
,
u
=
ALL
,
etype
=
None
):
"""Return the out-degree(s) of the given nodes.
"""Return the out-degree(s) of the given nodes.
...
@@ -3713,15 +3635,6 @@ class DGLGraph(object):
...
@@ -3713,15 +3635,6 @@ class DGLGraph(object):
else
:
else
:
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
False
,
fmt
)[
2
:]
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
False
,
fmt
)[
2
:]
def
adjacency_matrix_scipy
(
self
,
transpose
=
False
,
fmt
=
'csr'
,
return_edge_ids
=
None
):
"""DEPRECATED: please use ``dgl.adjacency_matrix(transpose, scipy_fmt=fmt)``.
"""
dgl_warning
(
'DGLGraph.adjacency_matrix_scipy is deprecated. '
'Please replace it with:
\n\n\t
'
'DGLGraph.adjacency_matrix(transpose, scipy_fmt="{}").
\n
'
.
format
(
fmt
))
return
self
.
adjacency_matrix
(
transpose
=
transpose
,
scipy_fmt
=
fmt
)
def
inc
(
self
,
typestr
,
ctx
=
F
.
cpu
(),
etype
=
None
):
def
inc
(
self
,
typestr
,
ctx
=
F
.
cpu
(),
etype
=
None
):
"""Return the incidence matrix representation of edges with the given
"""Return the incidence matrix representation of edges with the given
edge type.
edge type.
...
@@ -4283,7 +4196,7 @@ class DGLGraph(object):
...
@@ -4283,7 +4196,7 @@ class DGLGraph(object):
# Message passing
# Message passing
#################################################################
#################################################################
def
apply_nodes
(
self
,
func
,
v
=
ALL
,
ntype
=
None
,
inplace
=
False
):
def
apply_nodes
(
self
,
func
,
v
=
ALL
,
ntype
=
None
):
"""Update the features of the specified nodes by the provided function.
"""Update the features of the specified nodes by the provided function.
Parameters
Parameters
...
@@ -4303,8 +4216,6 @@ class DGLGraph(object):
...
@@ -4303,8 +4216,6 @@ class DGLGraph(object):
ntype : str, optional
ntype : str, optional
The node type name. Can be omitted if there is
The node type name. Can be omitted if there is
only one type of nodes in the graph.
only one type of nodes in the graph.
inplace : bool, optional
**DEPRECATED**.
Examples
Examples
--------
--------
...
@@ -4340,8 +4251,6 @@ class DGLGraph(object):
...
@@ -4340,8 +4251,6 @@ class DGLGraph(object):
--------
--------
apply_edges
apply_edges
"""
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
ntid
=
self
.
get_ntype_id
(
ntype
)
ntid
=
self
.
get_ntype_id
(
ntype
)
ntype
=
self
.
ntypes
[
ntid
]
ntype
=
self
.
ntypes
[
ntid
]
if
is_all
(
v
):
if
is_all
(
v
):
...
@@ -4351,7 +4260,7 @@ class DGLGraph(object):
...
@@ -4351,7 +4260,7 @@ class DGLGraph(object):
ndata
=
core
.
invoke_node_udf
(
self
,
v_id
,
ntype
,
func
,
orig_nid
=
v_id
)
ndata
=
core
.
invoke_node_udf
(
self
,
v_id
,
ntype
,
func
,
orig_nid
=
v_id
)
self
.
_set_n_repr
(
ntid
,
v
,
ndata
)
self
.
_set_n_repr
(
ntid
,
v
,
ndata
)
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
def
apply_edges
(
self
,
func
,
edges
=
ALL
,
etype
=
None
):
"""Update the features of the specified edges by the provided function.
"""Update the features of the specified edges by the provided function.
Parameters
Parameters
...
@@ -4382,9 +4291,6 @@ class DGLGraph(object):
...
@@ -4382,9 +4291,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
Notes
-----
-----
DGL recommends using DGL's bulit-in function for the :attr:`func` argument,
DGL recommends using DGL's bulit-in function for the :attr:`func` argument,
...
@@ -4435,8 +4341,6 @@ class DGLGraph(object):
...
@@ -4435,8 +4341,6 @@ class DGLGraph(object):
--------
--------
apply_nodes
apply_nodes
"""
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
# Graph with one relation type
# Graph with one relation type
if
self
.
_graph
.
number_of_etypes
()
==
1
or
etype
is
not
None
:
if
self
.
_graph
.
number_of_etypes
()
==
1
or
etype
is
not
None
:
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
...
@@ -4476,8 +4380,7 @@ class DGLGraph(object):
...
@@ -4476,8 +4380,7 @@ class DGLGraph(object):
message_func
,
message_func
,
reduce_func
,
reduce_func
,
apply_node_func
=
None
,
apply_node_func
=
None
,
etype
=
None
,
etype
=
None
):
inplace
=
False
):
"""Send messages along the specified edges and reduce them on
"""Send messages along the specified edges and reduce them on
the destination nodes to update their features.
the destination nodes to update their features.
...
@@ -4513,9 +4416,6 @@ class DGLGraph(object):
...
@@ -4513,9 +4416,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
Notes
-----
-----
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...
@@ -4558,7 +4458,7 @@ class DGLGraph(object):
...
@@ -4558,7 +4458,7 @@ class DGLGraph(object):
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])
... })
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.send_and_recv(g['follows'].edges(), fn.copy_
src
('h', 'm'),
>>> g.send_and_recv(g['follows'].edges(), fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
...
@@ -4588,8 +4488,6 @@ class DGLGraph(object):
...
@@ -4588,8 +4488,6 @@ class DGLGraph(object):
Note that the feature of node 0 remains the same as it has no incoming edges.
Note that the feature of node 0 remains the same as it has no incoming edges.
"""
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
# edge type
# edge type
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
...
@@ -4612,8 +4510,7 @@ class DGLGraph(object):
...
@@ -4612,8 +4510,7 @@ class DGLGraph(object):
message_func
,
message_func
,
reduce_func
,
reduce_func
,
apply_node_func
=
None
,
apply_node_func
=
None
,
etype
=
None
,
etype
=
None
):
inplace
=
False
):
"""Pull messages from the specified node(s)' predecessors along the
"""Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features.
specified edge type, aggregate them to update the node features.
...
@@ -4645,9 +4542,6 @@ class DGLGraph(object):
...
@@ -4645,9 +4542,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
Notes
-----
-----
* If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke
* If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke
...
@@ -4688,14 +4582,12 @@ class DGLGraph(object):
...
@@ -4688,14 +4582,12 @@ class DGLGraph(object):
Pull.
Pull.
>>> g['follows'].pull(2, fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].pull(2, fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
[1.],
[1.],
[1.]])
[1.]])
"""
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
if
len
(
v
)
==
0
:
if
len
(
v
)
==
0
:
# no computation
# no computation
...
@@ -4716,8 +4608,7 @@ class DGLGraph(object):
...
@@ -4716,8 +4608,7 @@ class DGLGraph(object):
message_func
,
message_func
,
reduce_func
,
reduce_func
,
apply_node_func
=
None
,
apply_node_func
=
None
,
etype
=
None
,
etype
=
None
):
inplace
=
False
):
"""Send message from the specified node(s) to their successors
"""Send message from the specified node(s) to their successors
along the specified edge type and update their node features.
along the specified edge type and update their node features.
...
@@ -4749,9 +4640,6 @@ class DGLGraph(object):
...
@@ -4749,9 +4640,6 @@ class DGLGraph(object):
Can be omitted if the graph has only one type of edges.
Can be omitted if the graph has only one type of edges.
inplace: bool, optional
**DEPRECATED**.
Notes
Notes
-----
-----
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
DGL recommends using DGL's bulit-in function for the :attr:`message_func`
...
@@ -4785,14 +4673,12 @@ class DGLGraph(object):
...
@@ -4785,14 +4673,12 @@ class DGLGraph(object):
Push.
Push.
>>> g['follows'].push(0, fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].push(0, fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
[0.],
[0.],
[0.]])
[0.]])
"""
"""
if
inplace
:
raise
DGLError
(
'The `inplace` option is removed in v0.5.'
)
edges
=
self
.
out_edges
(
u
,
form
=
'eid'
,
etype
=
etype
)
edges
=
self
.
out_edges
(
u
,
form
=
'eid'
,
etype
=
etype
)
self
.
send_and_recv
(
edges
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
)
self
.
send_and_recv
(
edges
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
)
...
@@ -4864,7 +4750,7 @@ class DGLGraph(object):
...
@@ -4864,7 +4750,7 @@ class DGLGraph(object):
Update all.
Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_
src
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g['follows'].update_all(fn.copy_
u
('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
[0.],
[0.],
...
@@ -4881,7 +4767,7 @@ class DGLGraph(object):
...
@@ -4881,7 +4767,7 @@ class DGLGraph(object):
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))
>>> g.update_all(fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
[4.]])
[4.]])
...
@@ -4989,8 +4875,8 @@ class DGLGraph(object):
...
@@ -4989,8 +4875,8 @@ class DGLGraph(object):
Update all.
Update all.
>>> g.multi_update_all(
>>> g.multi_update_all(
... {'follows': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h')),
... {'follows': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))},
... 'attracts': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))},
... "sum")
... "sum")
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[0.],
tensor([[0.],
...
@@ -5004,8 +4890,8 @@ class DGLGraph(object):
...
@@ -5004,8 +4890,8 @@ class DGLGraph(object):
Use the user-defined cross reducer.
Use the user-defined cross reducer.
>>> g.multi_update_all(
>>> g.multi_update_all(
... {'follows': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h')),
... {'follows': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h')),
... 'attracts': (fn.copy_
src
('h', 'm'), fn.sum('m', 'h'))},
... 'attracts': (fn.copy_
u
('h', 'm'), fn.sum('m', 'h'))},
... cross_sum)
... cross_sum)
"""
"""
all_out
=
defaultdict
(
list
)
all_out
=
defaultdict
(
list
)
...
@@ -5088,7 +4974,7 @@ class DGLGraph(object):
...
@@ -5088,7 +4974,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_
src
('h', 'm'),
>>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
... fn.sum('m', 'h'), etype='follows')
tensor([[1.],
tensor([[1.],
[2.],
[2.],
...
@@ -5151,7 +5037,7 @@ class DGLGraph(object):
...
@@ -5151,7 +5037,7 @@ class DGLGraph(object):
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])
>>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_
src
('h', 'm'),
>>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_
u
('h', 'm'),
... fn.sum('m', 'h'), etype='follows')
... fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
>>> g.nodes['user'].data['h']
tensor([[1.],
tensor([[1.],
...
@@ -6062,113 +5948,6 @@ class DGLGraph(object):
...
@@ -6062,113 +5948,6 @@ class DGLGraph(object):
"""
"""
return
self
.
astype
(
F
.
int32
)
return
self
.
astype
(
F
.
int32
)
#################################################################
# DEPRECATED: from the old DGLGraph
#################################################################
def
from_networkx
(
self
,
nx_graph
,
node_attrs
=
None
,
edge_attrs
=
None
):
"""DEPRECATED: please use
``dgl.from_networkx(nx_graph, node_attrs, edge_attrs)``
which will return a new graph created from the networkx graph.
"""
raise
DGLError
(
'DGLGraph.from_networkx is deprecated. Please call the following
\n\n
'
'
\t
dgl.from_networkx(nx_graph, node_attrs, edge_attrs)
\n\n
'
', which creates a new DGLGraph from the networkx graph.'
)
def
from_scipy_sparse_matrix
(
self
,
spmat
,
multigraph
=
None
):
"""DEPRECATED: please use
``dgl.from_scipy(spmat)``
which will return a new graph created from the scipy matrix.
"""
raise
DGLError
(
'DGLGraph.from_scipy_sparse_matrix is deprecated. '
'Please call the following
\n\n
'
'
\t
dgl.from_scipy(spmat)
\n\n
'
', which creates a new DGLGraph from the scipy matrix.'
)
def
register_apply_node_func
(
self
,
func
):
"""Deprecated: please directly call :func:`apply_nodes` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_apply_node_func is deprecated.'
' Please directly call apply_nodes with func as the argument.'
)
def
register_apply_edge_func
(
self
,
func
):
"""Deprecated: please directly call :func:`apply_edges` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_apply_edge_func is deprecated.'
' Please directly call apply_edges with func as the argument.'
)
def
register_message_func
(
self
,
func
):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_message_func is deprecated.'
' Please directly call update_all with func as the argument.'
)
def
register_reduce_func
(
self
,
func
):
"""Deprecated: please directly call :func:`update_all` with ``func``
as argument.
"""
raise
DGLError
(
'DGLGraph.register_reduce_func is deprecated.'
' Please directly call update_all with func as the argument.'
)
def
group_apply_edges
(
self
,
group_by
,
func
,
edges
=
ALL
,
etype
=
None
,
inplace
=
False
):
"""**DEPRECATED**: The API is removed in 0.5."""
raise
DGLError
(
'DGLGraph.group_apply_edges is removed in 0.5.'
)
def
send
(
self
,
edges
,
message_func
,
etype
=
None
):
"""Send messages along the given edges with the same edge type.
DEPRECATE: please use send_and_recv, update_all.
"""
raise
DGLError
(
'DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges
\n
'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv
\n
'
' and set the message function as dgl.function.copy_e to conduct message
\n
'
' aggregation.'
)
def
recv
(
self
,
v
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
,
inplace
=
False
):
r
"""Receive and reduce incoming messages and update the features of node(s) :math:`v`.
DEPRECATE: please use send_and_recv, update_all.
"""
raise
DGLError
(
'DGLGraph.recv is deprecated. As a replacement, use DGLGraph.apply_edges
\n
'
' API to compute messages as edge data. Then use DGLGraph.send_and_recv
\n
'
' and set the message function as dgl.function.copy_e to conduct message
\n
'
' aggregation.'
)
def
multi_recv
(
self
,
v
,
reducer_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""Receive messages from multiple edge types and perform aggregation.
DEPRECATE: please use multi_send_and_recv, multi_update_all.
"""
raise
DGLError
(
'DGLGraph.multi_recv is deprecated. As a replacement,
\n
'
' use DGLGraph.apply_edges API to compute messages as edge data.
\n
'
' Then use DGLGraph.multi_send_and_recv and set the message function
\n
'
' as dgl.function.copy_e to conduct message aggregation.'
)
def
multi_send_and_recv
(
self
,
etype_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""**DEPRECATED**: The API is removed in v0.5."""
raise
DGLError
(
'DGLGraph.multi_pull is removed in v0.5. As a replacement,
\n
'
' use DGLGraph.edge_subgraph to extract the subgraph first
\n
'
' and then call DGLGraph.multi_update_all.'
)
def
multi_pull
(
self
,
v
,
etype_dict
,
cross_reducer
,
apply_node_func
=
None
,
inplace
=
False
):
r
"""**DEPRECATED**: The API is removed in v0.5."""
raise
DGLError
(
'DGLGraph.multi_pull is removed in v0.5. As a replacement,
\n
'
' use DGLGraph.edge_subgraph to extract the subgraph first
\n
'
' and then call DGLGraph.multi_update_all.'
)
def
readonly
(
self
,
readonly_state
=
True
):
"""Deprecated: DGLGraph will always be mutable."""
dgl_warning
(
'DGLGraph.readonly is deprecated in v0.5.
\n
'
'DGLGraph now always supports mutable operations like add_nodes'
' and add_edges.'
)
############################################################
############################################################
# Internal APIs
# Internal APIs
############################################################
############################################################
...
...
python/dgl/nn/mxnet/conv/graphconv.py
View file @
61139302
...
@@ -261,13 +261,13 @@ class GraphConv(gluon.Block):
...
@@ -261,13 +261,13 @@ class GraphConv(gluon.Block):
if
weight
is
not
None
:
if
weight
is
not
None
:
feat_src
=
mx
.
nd
.
dot
(
feat_src
,
weight
)
feat_src
=
mx
.
nd
.
dot
(
feat_src
,
weight
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
.
pop
(
'h'
)
rst
=
graph
.
dstdata
.
pop
(
'h'
)
else
:
else
:
# aggregate first then mult W
# aggregate first then mult W
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
.
pop
(
'h'
)
rst
=
graph
.
dstdata
.
pop
(
'h'
)
if
weight
is
not
None
:
if
weight
is
not
None
:
...
...
python/dgl/nn/mxnet/conv/tagconv.py
View file @
61139302
...
@@ -114,7 +114,7 @@ class TAGConv(gluon.Block):
...
@@ -114,7 +114,7 @@ class TAGConv(gluon.Block):
rst
=
rst
*
norm
rst
=
rst
*
norm
graph
.
ndata
[
'h'
]
=
rst
graph
.
ndata
[
'h'
]
=
rst
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
ndata
[
'h'
]
rst
=
graph
.
ndata
[
'h'
]
rst
=
rst
*
norm
rst
=
rst
*
norm
...
...
python/dgl/nn/pytorch/conv/ginconv.py
View file @
61139302
...
@@ -136,7 +136,7 @@ class GINConv(nn.Module):
...
@@ -136,7 +136,7 @@ class GINConv(nn.Module):
"""
"""
_reducer
=
getattr
(
fn
,
self
.
_aggregator_type
)
_reducer
=
getattr
(
fn
,
self
.
_aggregator_type
)
with
graph
.
local_scope
():
with
graph
.
local_scope
():
aggregate_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
aggregate_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/conv/graphconv.py
View file @
61139302
...
@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module):
...
@@ -114,13 +114,13 @@ class EdgeWeightNorm(nn.Module):
if
self
.
_norm
==
'both'
:
if
self
.
_norm
==
'both'
:
reversed_g
=
reverse
(
graph
)
reversed_g
=
reverse
(
graph
)
reversed_g
.
edata
[
'_edge_w'
]
=
edge_weight
reversed_g
.
edata
[
'_edge_w'
]
=
edge_weight
reversed_g
.
update_all
(
fn
.
copy_e
dge
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'out_weight'
))
reversed_g
.
update_all
(
fn
.
copy_e
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'out_weight'
))
degs
=
reversed_g
.
dstdata
[
'out_weight'
]
+
self
.
_eps
degs
=
reversed_g
.
dstdata
[
'out_weight'
]
+
self
.
_eps
norm
=
th
.
pow
(
degs
,
-
0.5
)
norm
=
th
.
pow
(
degs
,
-
0.5
)
graph
.
srcdata
[
'_src_out_w'
]
=
norm
graph
.
srcdata
[
'_src_out_w'
]
=
norm
if
self
.
_norm
!=
'none'
:
if
self
.
_norm
!=
'none'
:
graph
.
update_all
(
fn
.
copy_e
dge
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'in_weight'
))
graph
.
update_all
(
fn
.
copy_e
(
'_edge_w'
,
'm'
),
fn
.
sum
(
'm'
,
'in_weight'
))
degs
=
graph
.
dstdata
[
'in_weight'
]
+
self
.
_eps
degs
=
graph
.
dstdata
[
'in_weight'
]
+
self
.
_eps
if
self
.
_norm
==
'both'
:
if
self
.
_norm
==
'both'
:
norm
=
th
.
pow
(
degs
,
-
0.5
)
norm
=
th
.
pow
(
degs
,
-
0.5
)
...
@@ -389,7 +389,7 @@ class GraphConv(nn.Module):
...
@@ -389,7 +389,7 @@ class GraphConv(nn.Module):
'the issue. Setting ``allow_zero_in_degree`` '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'to be `True` when constructing this module will '
'suppress the check and let the code run.'
)
'suppress the check and let the code run.'
)
aggregate_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
aggregate_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/conv/sageconv.py
View file @
61139302
...
@@ -213,7 +213,7 @@ class SAGEConv(nn.Module):
...
@@ -213,7 +213,7 @@ class SAGEConv(nn.Module):
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
feat_src
=
feat_dst
=
self
.
feat_drop
(
feat
)
if
graph
.
is_block
:
if
graph
.
is_block
:
feat_dst
=
feat_src
[:
graph
.
number_of_dst_nodes
()]
feat_dst
=
feat_src
[:
graph
.
number_of_dst_nodes
()]
msg_fn
=
fn
.
copy_
src
(
'h'
,
'm'
)
msg_fn
=
fn
.
copy_
u
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
if
edge_weight
is
not
None
:
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
assert
edge_weight
.
shape
[
0
]
==
graph
.
number_of_edges
()
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
graph
.
edata
[
'_edge_weight'
]
=
edge_weight
...
...
python/dgl/nn/pytorch/sparse_emb.py
View file @
61139302
...
@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding
...
@@ -334,19 +334,6 @@ class NodeEmbedding: # NodeEmbedding
"""
"""
self
.
_trace
=
[]
self
.
_trace
=
[]
@
property
def
emb_tensor
(
self
):
"""Return the tensor storing the node embeddings
DEPRECATED: renamed weight
Returns
-------
torch.Tensor
The tensor storing the node embeddings
"""
return
self
.
_tensor
@
property
@
property
def
weight
(
self
):
def
weight
(
self
):
"""Return the tensor storing the node embeddings
"""Return the tensor storing the node embeddings
...
...
python/dgl/nn/tensorflow/conv/graphconv.py
View file @
61139302
...
@@ -253,13 +253,13 @@ class GraphConv(layers.Layer):
...
@@ -253,13 +253,13 @@ class GraphConv(layers.Layer):
if
weight
is
not
None
:
if
weight
is
not
None
:
feat_src
=
tf
.
matmul
(
feat_src
,
weight
)
feat_src
=
tf
.
matmul
(
feat_src
,
weight
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
[
'h'
]
rst
=
graph
.
dstdata
[
'h'
]
else
:
else
:
# aggregate first then mult W
# aggregate first then mult W
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
src
=
'h'
,
out
=
'm'
),
graph
.
update_all
(
fn
.
copy_
u
(
u
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
))
rst
=
graph
.
dstdata
[
'h'
]
rst
=
graph
.
dstdata
[
'h'
]
if
weight
is
not
None
:
if
weight
is
not
None
:
...
...
python/dgl/nn/tensorflow/conv/sageconv.py
View file @
61139302
...
@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer):
...
@@ -166,24 +166,24 @@ class SAGEConv(layers.Layer):
if
self
.
_aggre_type
==
'mean'
:
if
self
.
_aggre_type
==
'mean'
:
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
mean
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
dstdata
[
'neigh'
]
h_neigh
=
graph
.
dstdata
[
'neigh'
]
elif
self
.
_aggre_type
==
'gcn'
:
elif
self
.
_aggre_type
==
'gcn'
:
check_eq_shape
(
feat
)
check_eq_shape
(
feat
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
dstdata
[
'h'
]
=
feat_dst
# same as above if homogeneous
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'neigh'
))
# divide in_degrees
# divide in_degrees
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
degs
=
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
)
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
]
h_neigh
=
(
graph
.
dstdata
[
'neigh'
]
+
graph
.
dstdata
[
'h'
]
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
)
/
(
tf
.
expand_dims
(
degs
,
-
1
)
+
1
)
elif
self
.
_aggre_type
==
'pool'
:
elif
self
.
_aggre_type
==
'pool'
:
graph
.
srcdata
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat_src
))
graph
.
srcdata
[
'h'
]
=
tf
.
nn
.
relu
(
self
.
fc_pool
(
feat_src
))
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
fn
.
max
(
'm'
,
'neigh'
))
h_neigh
=
graph
.
dstdata
[
'neigh'
]
h_neigh
=
graph
.
dstdata
[
'neigh'
]
elif
self
.
_aggre_type
==
'lstm'
:
elif
self
.
_aggre_type
==
'lstm'
:
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
fn
.
copy_
src
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
graph
.
update_all
(
fn
.
copy_
u
(
'h'
,
'm'
),
self
.
_lstm_reducer
)
h_neigh
=
graph
.
dstdata
[
'neigh'
]
h_neigh
=
graph
.
dstdata
[
'neigh'
]
else
:
else
:
raise
KeyError
(
raise
KeyError
(
...
...
python/dgl/optim/pytorch/sparse_optim.py
View file @
61139302
...
@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer):
...
@@ -526,7 +526,7 @@ class SparseAdagrad(SparseGradOptimizer):
),
"SparseAdagrad only supports dgl.nn.NodeEmbedding"
),
"SparseAdagrad only supports dgl.nn.NodeEmbedding"
emb_name
=
emb
.
name
emb_name
=
emb
.
name
if
th
.
device
(
emb
.
emb_tensor
.
device
)
==
th
.
device
(
"cpu"
):
if
th
.
device
(
emb
.
weight
.
device
)
==
th
.
device
(
"cpu"
):
# if our embedding is on the CPU, our state also has to be
# if our embedding is on the CPU, our state also has to be
if
self
.
_rank
<
0
:
if
self
.
_rank
<
0
:
state
=
th
.
empty
(
state
=
th
.
empty
(
...
@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer):
...
@@ -550,9 +550,9 @@ class SparseAdagrad(SparseGradOptimizer):
else
:
else
:
# distributed state on on gpu
# distributed state on on gpu
state
=
th
.
empty
(
state
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
th
.
float32
,
dtype
=
th
.
float32
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
).
zero_
()
emb
.
set_optm_state
(
state
)
emb
.
set_optm_state
(
state
)
...
@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer):
...
@@ -689,7 +689,7 @@ class SparseAdam(SparseGradOptimizer):
),
"SparseAdam only supports dgl.nn.NodeEmbedding"
),
"SparseAdam only supports dgl.nn.NodeEmbedding"
emb_name
=
emb
.
name
emb_name
=
emb
.
name
self
.
_is_using_uva
[
emb_name
]
=
self
.
_use_uva
self
.
_is_using_uva
[
emb_name
]
=
self
.
_use_uva
if
th
.
device
(
emb
.
emb_tensor
.
device
)
==
th
.
device
(
"cpu"
):
if
th
.
device
(
emb
.
weight
.
device
)
==
th
.
device
(
"cpu"
):
# if our embedding is on the CPU, our state also has to be
# if our embedding is on the CPU, our state also has to be
if
self
.
_rank
<
0
:
if
self
.
_rank
<
0
:
state_step
=
th
.
empty
(
state_step
=
th
.
empty
(
...
@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer):
...
@@ -743,19 +743,19 @@ class SparseAdam(SparseGradOptimizer):
# distributed state on on gpu
# distributed state on on gpu
state_step
=
th
.
empty
(
state_step
=
th
.
empty
(
[
emb
.
emb_tensor
.
shape
[
0
]],
[
emb
.
weight
.
shape
[
0
]],
dtype
=
th
.
int32
,
dtype
=
th
.
int32
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
).
zero_
()
state_mem
=
th
.
empty
(
state_mem
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
self
.
_dtype
,
dtype
=
self
.
_dtype
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
).
zero_
()
state_power
=
th
.
empty
(
state_power
=
th
.
empty
(
emb
.
emb_tensor
.
shape
,
emb
.
weight
.
shape
,
dtype
=
self
.
_dtype
,
dtype
=
self
.
_dtype
,
device
=
emb
.
emb_tensor
.
device
,
device
=
emb
.
weight
.
device
,
).
zero_
()
).
zero_
()
state
=
(
state_step
,
state_mem
,
state_power
)
state
=
(
state_step
,
state_mem
,
state_power
)
emb
.
set_optm_state
(
state
)
emb
.
set_optm_state
(
state
)
...
...
tests/compute/test_basics.py
View file @
61139302
...
@@ -32,10 +32,10 @@ def generate_graph_old(grad=False):
...
@@ -32,10 +32,10 @@ def generate_graph_old(grad=False):
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
# 17 edges
# 17 edges
for
i
in
range
(
1
,
9
):
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
i
,
9
)
# add a back flow from 9 to 0
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
g
.
add_edge
s
(
9
,
0
)
g
=
g
.
to
(
F
.
ctx
())
g
=
g
.
to
(
F
.
ctx
())
ncol
=
F
.
randn
((
10
,
D
))
ncol
=
F
.
randn
((
10
,
D
))
ecol
=
F
.
randn
((
17
,
D
))
ecol
=
F
.
randn
((
17
,
D
))
...
@@ -431,8 +431,8 @@ def test_dynamic_addition():
...
@@ -431,8 +431,8 @@ def test_dynamic_addition():
assert
g
.
ndata
[
'h1'
].
shape
[
0
]
==
g
.
ndata
[
'h2'
].
shape
[
0
]
==
N
+
3
assert
g
.
ndata
[
'h1'
].
shape
[
0
]
==
g
.
ndata
[
'h2'
].
shape
[
0
]
==
N
+
3
# Test edge addition
# Test edge addition
g
.
add_edge
(
0
,
1
)
g
.
add_edge
s
(
0
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
s
(
1
,
0
)
g
.
edata
.
update
({
'h1'
:
F
.
randn
((
2
,
D
)),
g
.
edata
.
update
({
'h1'
:
F
.
randn
((
2
,
D
)),
'h2'
:
F
.
randn
((
2
,
D
))})
'h2'
:
F
.
randn
((
2
,
D
))})
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
2
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
2
...
@@ -441,12 +441,12 @@ def test_dynamic_addition():
...
@@ -441,12 +441,12 @@ def test_dynamic_addition():
g
.
edata
[
'h1'
]
=
F
.
randn
((
4
,
D
))
g
.
edata
[
'h1'
]
=
F
.
randn
((
4
,
D
))
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
4
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
4
g
.
add_edge
(
1
,
2
)
g
.
add_edge
s
(
1
,
2
)
g
.
edges
[
4
].
data
[
'h1'
]
=
F
.
randn
((
1
,
D
))
g
.
edges
[
4
].
data
[
'h1'
]
=
F
.
randn
((
1
,
D
))
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
5
assert
g
.
edata
[
'h1'
].
shape
[
0
]
==
g
.
edata
[
'h2'
].
shape
[
0
]
==
5
# test add edge with part of the features
# test add edge with part of the features
g
.
add_edge
(
2
,
1
,
{
'h1'
:
F
.
randn
((
1
,
D
))})
g
.
add_edge
s
(
2
,
1
,
{
'h1'
:
F
.
randn
((
1
,
D
))})
assert
len
(
g
.
edata
[
'h1'
])
==
len
(
g
.
edata
[
'h2'
])
assert
len
(
g
.
edata
[
'h1'
])
==
len
(
g
.
edata
[
'h2'
])
...
...
tests/compute/test_batched_graph.py
View file @
61139302
...
@@ -15,10 +15,10 @@ def tree1(idtype):
...
@@ -15,10 +15,10 @@ def tree1(idtype):
"""
"""
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
g
.
add_nodes
(
5
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
s
(
3
,
1
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
s
(
4
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
s
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
s
(
2
,
0
)
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
return
g
return
g
...
@@ -34,10 +34,10 @@ def tree2(idtype):
...
@@ -34,10 +34,10 @@ def tree2(idtype):
"""
"""
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
=
dgl
.
graph
(([],
[])).
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
g
.
add_nodes
(
5
)
g
.
add_edge
(
2
,
4
)
g
.
add_edge
s
(
2
,
4
)
g
.
add_edge
(
0
,
4
)
g
.
add_edge
s
(
0
,
4
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
s
(
4
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
add_edge
s
(
3
,
1
)
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
ndata
[
'h'
]
=
F
.
tensor
([
0
,
1
,
2
,
3
,
4
])
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
g
.
edata
[
'h'
]
=
F
.
randn
((
4
,
10
))
return
g
return
g
...
@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype):
...
@@ -191,8 +191,8 @@ def test_batched_edge_ordering(idtype):
e2
=
F
.
randn
((
6
,
10
))
e2
=
F
.
randn
((
6
,
10
))
g2
.
edata
[
'h'
]
=
e2
g2
.
edata
[
'h'
]
=
e2
g
=
dgl
.
batch
([
g1
,
g2
])
g
=
dgl
.
batch
([
g1
,
g2
])
r1
=
g
.
edata
[
'h'
][
g
.
edge_id
(
4
,
5
)]
r1
=
g
.
edata
[
'h'
][
g
.
edge_id
s
(
4
,
5
)]
r2
=
g1
.
edata
[
'h'
][
g1
.
edge_id
(
4
,
5
)]
r2
=
g1
.
edata
[
'h'
][
g1
.
edge_id
s
(
4
,
5
)]
assert
F
.
array_equal
(
r1
,
r2
)
assert
F
.
array_equal
(
r1
,
r2
)
@
parametrize_idtype
@
parametrize_idtype
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment