Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
61139302
Unverified
Commit
61139302
authored
Dec 01, 2022
by
peizhou001
Committed by
GitHub
Dec 01, 2022
Browse files
[API Deprecation] Remove candidates in DGLGraph (#4946)
parent
e088acac
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
112 additions
and
130 deletions
+112
-130
tests/compute/test_batched_heterograph.py
tests/compute/test_batched_heterograph.py
+0
-4
tests/compute/test_graph.py
tests/compute/test_graph.py
+34
-39
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+5
-5
tests/compute/test_kernel.py
tests/compute/test_kernel.py
+12
-12
tests/compute/test_pickle.py
tests/compute/test_pickle.py
+0
-1
tests/compute/test_propagate.py
tests/compute/test_propagate.py
+4
-4
tests/compute/test_removal.py
tests/compute/test_removal.py
+13
-13
tests/compute/test_sampler.py
tests/compute/test_sampler.py
+3
-3
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+1
-1
tests/compute/test_shared_mem.py
tests/compute/test_shared_mem.py
+0
-1
tests/compute/test_specialization.py
tests/compute/test_specialization.py
+19
-19
tests/compute/test_subgraph.py
tests/compute/test_subgraph.py
+3
-3
tests/compute/test_transform.py
tests/compute/test_transform.py
+7
-7
tests/compute/utils.py
tests/compute/utils.py
+0
-2
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+0
-5
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+4
-4
tests/distributed/test_shared_mem_store.py
tests/distributed/test_shared_mem_store.py
+3
-3
tutorials/models/1_gnn/6_line_graph.py
tutorials/models/1_gnn/6_line_graph.py
+2
-2
tutorials/models/2_small_graph/3_tree-lstm.py
tutorials/models/2_small_graph/3_tree-lstm.py
+1
-1
tutorials/models/3_generative_model/5_dgmg.py
tutorials/models/3_generative_model/5_dgmg.py
+1
-1
No files found.
tests/compute/test_batched_heterograph.py
View file @
61139302
...
...
@@ -237,10 +237,6 @@ def test_features(idtype):
node_attrs
=
{
'user'
:
[
'h2'
],
'game'
:
[
'h2'
]},
edge_attrs
=
{(
'user'
,
'follows'
,
'user'
):
[
'h1'
]})
# test legacy
bg
=
dgl
.
batch
([
g1
,
g2
],
edge_attrs
=
[
'h1'
])
assert
'h2'
not
in
bg
.
edges
[
'follows'
].
data
.
keys
()
@
unittest
.
skipIf
(
F
.
backend_name
==
'mxnet'
,
reason
=
"MXNet does not support split array with zero-length segment."
)
@
parametrize_idtype
...
...
tests/compute/test_graph.py
View file @
61139302
...
...
@@ -61,21 +61,19 @@ def test_query():
assert
g
.
number_of_edges
()
==
20
for
i
in
range
(
10
):
assert
g
.
has_node
(
i
)
assert
i
in
g
assert
not
g
.
has_node
(
11
)
assert
not
11
in
g
assert
g
.
has_nodes
(
i
)
assert
not
g
.
has_nodes
(
11
)
assert
F
.
allclose
(
g
.
has_nodes
([
0
,
2
,
10
,
11
]),
F
.
tensor
([
1
,
1
,
0
,
0
]))
src
,
dst
=
edge_pair_input
()
for
u
,
v
in
zip
(
src
,
dst
):
assert
g
.
has_edge_between
(
u
,
v
)
assert
not
g
.
has_edge_between
(
0
,
0
)
assert
g
.
has_edge
s
_between
(
u
,
v
)
assert
not
g
.
has_edge
s
_between
(
0
,
0
)
assert
F
.
allclose
(
g
.
has_edges_between
([
0
,
0
,
3
],
[
0
,
9
,
8
]),
F
.
tensor
([
0
,
1
,
1
]))
assert
set
(
F
.
asnumpy
(
g
.
predecessors
(
9
)))
==
set
([
0
,
5
,
7
,
4
])
assert
set
(
F
.
asnumpy
(
g
.
successors
(
2
)))
==
set
([
7
,
3
])
assert
g
.
edge_id
(
4
,
4
)
==
5
assert
g
.
edge_id
s
(
4
,
4
)
==
5
assert
F
.
allclose
(
g
.
edge_ids
([
4
,
0
],
[
4
,
9
]),
F
.
tensor
([
5
,
0
]))
src
,
dst
=
g
.
find_edges
([
3
,
6
,
5
])
...
...
@@ -110,11 +108,11 @@ def test_query():
assert
set
(
tup
)
==
set
(
t_tup
)
assert
list
(
F
.
asnumpy
(
src
))
==
sorted
(
list
(
F
.
asnumpy
(
src
)))
assert
g
.
in_degree
(
0
)
==
0
assert
g
.
in_degree
(
9
)
==
4
assert
g
.
in_degree
s
(
0
)
==
0
assert
g
.
in_degree
s
(
9
)
==
4
assert
F
.
allclose
(
g
.
in_degrees
([
0
,
9
]),
F
.
tensor
([
0
,
4
]))
assert
g
.
out_degree
(
8
)
==
0
assert
g
.
out_degree
(
9
)
==
1
assert
g
.
out_degree
s
(
8
)
==
0
assert
g
.
out_degree
s
(
9
)
==
1
assert
F
.
allclose
(
g
.
out_degrees
([
8
,
9
]),
F
.
tensor
([
0
,
1
]))
assert
np
.
array_equal
(
...
...
@@ -132,16 +130,14 @@ def test_query():
assert
g
.
number_of_edges
()
==
20
for
i
in
range
(
10
):
assert
g
.
has_node
(
i
)
assert
i
in
g
assert
not
g
.
has_node
(
11
)
assert
not
11
in
g
assert
g
.
has_nodes
(
i
)
assert
not
g
.
has_nodes
(
11
)
assert
F
.
allclose
(
g
.
has_nodes
([
0
,
2
,
10
,
11
]),
F
.
tensor
([
1
,
1
,
0
,
0
]))
src
,
dst
=
edge_pair_input
(
sort
=
True
)
for
u
,
v
in
zip
(
src
,
dst
):
assert
g
.
has_edge_between
(
u
,
v
)
assert
not
g
.
has_edge_between
(
0
,
0
)
assert
g
.
has_edge
s
_between
(
u
,
v
)
assert
not
g
.
has_edge
s
_between
(
0
,
0
)
assert
F
.
allclose
(
g
.
has_edges_between
([
0
,
0
,
3
],
[
0
,
9
,
8
]),
F
.
tensor
([
0
,
1
,
1
]))
assert
set
(
F
.
asnumpy
(
g
.
predecessors
(
9
)))
==
set
([
0
,
5
,
7
,
4
])
assert
set
(
F
.
asnumpy
(
g
.
successors
(
2
)))
==
set
([
7
,
3
])
...
...
@@ -149,7 +145,7 @@ def test_query():
# src = [0 0 0 1 1 2 2 3 3 4 4 4 4 5 5 6 7 7 7 9]
# dst = [4 6 9 3 5 3 7 5 8 1 3 4 9 1 9 6 2 8 9 2]
# eid = [0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9]
assert
g
.
edge_id
(
4
,
4
)
==
11
assert
g
.
edge_id
s
(
4
,
4
)
==
11
assert
F
.
allclose
(
g
.
edge_ids
([
4
,
0
],
[
4
,
9
]),
F
.
tensor
([
11
,
2
]))
src
,
dst
=
g
.
find_edges
([
3
,
6
,
5
])
...
...
@@ -184,11 +180,11 @@ def test_query():
assert
set
(
tup
)
==
set
(
t_tup
)
assert
list
(
F
.
asnumpy
(
src
))
==
sorted
(
list
(
F
.
asnumpy
(
src
)))
assert
g
.
in_degree
(
0
)
==
0
assert
g
.
in_degree
(
9
)
==
4
assert
g
.
in_degree
s
(
0
)
==
0
assert
g
.
in_degree
s
(
9
)
==
4
assert
F
.
allclose
(
g
.
in_degrees
([
0
,
9
]),
F
.
tensor
([
0
,
4
]))
assert
g
.
out_degree
(
8
)
==
0
assert
g
.
out_degree
(
9
)
==
1
assert
g
.
out_degree
s
(
8
)
==
0
assert
g
.
out_degree
s
(
9
)
==
1
assert
F
.
allclose
(
g
.
out_degrees
([
8
,
9
]),
F
.
tensor
([
0
,
1
]))
assert
np
.
array_equal
(
...
...
@@ -205,17 +201,17 @@ def test_query():
g
=
gen_by_mutation
()
eids
=
g
.
edge_ids
([
4
,
0
],
[
4
,
9
])
assert
eids
.
shape
[
0
]
==
2
eid
=
g
.
edge_id
(
4
,
4
)
eid
=
g
.
edge_id
s
(
4
,
4
)
assert
isinstance
(
eid
,
numbers
.
Number
)
with
pytest
.
raises
(
DGLError
):
eids
=
g
.
edge_ids
([
9
,
0
],
[
4
,
9
])
with
pytest
.
raises
(
DGLError
):
eid
=
g
.
edge_id
(
4
,
5
)
eid
=
g
.
edge_id
s
(
4
,
5
)
g
.
add_edge
(
0
,
4
)
g
.
add_edge
s
(
0
,
4
)
eids
=
g
.
edge_ids
([
0
,
0
],
[
4
,
9
])
eid
=
g
.
edge_id
(
0
,
4
)
eid
=
g
.
edge_id
s
(
0
,
4
)
_test
(
gen_by_mutation
())
_test
(
gen_from_data
(
elist_input
(),
False
,
False
))
...
...
@@ -260,11 +256,11 @@ def test_scipy_adjmat():
def
test_incmat
():
g
=
dgl
.
DGLGraph
()
g
.
add_nodes
(
4
)
g
.
add_edge
(
0
,
1
)
# 0
g
.
add_edge
(
0
,
2
)
# 1
g
.
add_edge
(
0
,
3
)
# 2
g
.
add_edge
(
2
,
3
)
# 3
g
.
add_edge
(
1
,
1
)
# 4
g
.
add_edge
s
(
0
,
1
)
# 0
g
.
add_edge
s
(
0
,
2
)
# 1
g
.
add_edge
s
(
0
,
3
)
# 2
g
.
add_edge
s
(
2
,
3
)
# 3
g
.
add_edge
s
(
1
,
1
)
# 4
inc_in
=
F
.
sparse_to_numpy
(
g
.
incidence_matrix
(
'in'
))
inc_out
=
F
.
sparse_to_numpy
(
g
.
incidence_matrix
(
'out'
))
inc_both
=
F
.
sparse_to_numpy
(
g
.
incidence_matrix
(
'both'
))
...
...
@@ -323,18 +319,17 @@ def test_hypersparse_query():
g
.
add_nodes
(
1000001
)
g
.
add_edges
([
0
],
[
1
])
for
i
in
range
(
10
):
assert
g
.
has_node
(
i
)
assert
i
in
g
assert
not
g
.
has_node
(
1000002
)
assert
g
.
edge_id
(
0
,
1
)
==
0
assert
g
.
has_nodes
(
i
)
assert
not
g
.
has_nodes
(
1000002
)
assert
g
.
edge_ids
(
0
,
1
)
==
0
src
,
dst
=
g
.
find_edges
([
0
])
src
,
dst
,
eid
=
g
.
in_edges
(
1
,
form
=
'all'
)
src
,
dst
,
eid
=
g
.
out_edges
(
0
,
form
=
'all'
)
src
,
dst
=
g
.
edges
()
assert
g
.
in_degree
(
0
)
==
0
assert
g
.
in_degree
(
1
)
==
1
assert
g
.
out_degree
(
0
)
==
1
assert
g
.
out_degree
(
1
)
==
0
assert
g
.
in_degree
s
(
0
)
==
0
assert
g
.
in_degree
s
(
1
)
==
1
assert
g
.
out_degree
s
(
0
)
==
1
assert
g
.
out_degree
s
(
1
)
==
0
def
test_empty_data_initialized
():
g
=
dgl
.
DGLGraph
()
...
...
tests/compute/test_heterograph.py
View file @
61139302
...
...
@@ -269,12 +269,12 @@ def test_query(idtype):
# number of edges
assert
[
g
.
num_edges
(
etype
)
for
etype
in
etypes
]
==
[
2
,
4
,
2
,
2
]
#
has_node &
has_nodes
# has_nodes
for
ntype
in
ntypes
:
n
=
g
.
number_of_nodes
(
ntype
)
for
i
in
range
(
n
):
assert
g
.
has_node
(
i
,
ntype
)
assert
not
g
.
has_node
(
n
,
ntype
)
assert
g
.
has_node
s
(
i
,
ntype
)
assert
not
g
.
has_node
s
(
n
,
ntype
)
assert
np
.
array_equal
(
F
.
asnumpy
(
g
.
has_nodes
([
0
,
n
],
ntype
)).
astype
(
'int32'
),
[
1
,
0
])
...
...
@@ -310,7 +310,7 @@ def test_query(idtype):
assert
set
(
F
.
asnumpy
(
v
).
tolist
())
==
set
(
succ
)
assert
g
.
out_degrees
(
0
,
etype
)
==
len
(
succ
)
#
edge_id &
edge_ids
# edge_ids
for
i
,
(
src
,
dst
)
in
enumerate
(
zip
(
srcs
,
dsts
)):
assert
g
.
edge_ids
(
src
,
dst
,
etype
=
etype
)
==
i
_
,
_
,
eid
=
g
.
edge_ids
(
src
,
dst
,
etype
=
etype
,
return_uv
=
True
)
...
...
@@ -750,7 +750,7 @@ def test_view1(idtype):
assert
set
(
F
.
asnumpy
(
v
).
tolist
())
==
set
(
succ
)
assert
g
.
out_degrees
(
0
)
==
len
(
succ
)
#
edge_id &
edge_ids
# edge_ids
for
i
,
(
src
,
dst
)
in
enumerate
(
zip
(
srcs
,
dsts
)):
assert
g
.
edge_ids
(
src
,
dst
,
etype
=
etype
)
==
i
_
,
_
,
eid
=
g
.
edge_ids
(
src
,
dst
,
etype
=
etype
,
return_uv
=
True
)
...
...
tests/compute/test_kernel.py
View file @
61139302
...
...
@@ -93,10 +93,10 @@ def test_copy_src_reduce():
with
F
.
record_grad
():
if
partial
:
g
.
pull
(
nid
,
fn
.
copy_
src
(
src
=
'u'
,
out
=
'm'
),
g
.
pull
(
nid
,
fn
.
copy_
u
(
u
=
'u'
,
out
=
'm'
),
builtin
[
red
](
msg
=
'm'
,
out
=
'r1'
))
else
:
g
.
update_all
(
fn
.
copy_
src
(
src
=
'u'
,
out
=
'm'
),
g
.
update_all
(
fn
.
copy_
u
(
u
=
'u'
,
out
=
'm'
),
builtin
[
red
](
msg
=
'm'
,
out
=
'r1'
))
r1
=
g
.
ndata
[
'r1'
]
F
.
backward
(
F
.
reduce_sum
(
r1
))
...
...
@@ -155,10 +155,10 @@ def test_copy_edge_reduce():
with
F
.
record_grad
():
if
partial
:
g
.
pull
(
nid
,
fn
.
copy_e
dge
(
edg
e
=
'e'
,
out
=
'm'
),
g
.
pull
(
nid
,
fn
.
copy_e
(
e
=
'e'
,
out
=
'm'
),
builtin
[
red
](
msg
=
'm'
,
out
=
'r1'
))
else
:
g
.
update_all
(
fn
.
copy_e
dge
(
edg
e
=
'e'
,
out
=
'm'
),
g
.
update_all
(
fn
.
copy_e
(
e
=
'e'
,
out
=
'm'
),
builtin
[
red
](
msg
=
'm'
,
out
=
'r1'
))
r1
=
g
.
ndata
[
'r1'
]
F
.
backward
(
F
.
reduce_sum
(
r1
))
...
...
@@ -339,14 +339,14 @@ def test_all_binary_builtins():
# NOTE(zihao): add self-loop to avoid zero-degree nodes.
g
.
add_edges
(
g
.
nodes
(),
g
.
nodes
())
for
i
in
range
(
2
,
18
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
1
,
i
)
g
.
add_edge
(
i
,
18
)
g
.
add_edge
(
i
,
19
)
g
.
add_edge
(
18
,
0
)
g
.
add_edge
(
18
,
1
)
g
.
add_edge
(
19
,
0
)
g
.
add_edge
(
19
,
1
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
1
,
i
)
g
.
add_edge
s
(
i
,
18
)
g
.
add_edge
s
(
i
,
19
)
g
.
add_edge
s
(
18
,
0
)
g
.
add_edge
s
(
18
,
1
)
g
.
add_edge
s
(
19
,
0
)
g
.
add_edge
s
(
19
,
1
)
g
=
g
.
to
(
F
.
ctx
())
nid
=
F
.
tensor
([
0
,
1
,
4
,
5
,
7
,
12
,
14
,
15
,
18
,
19
],
g
.
idtype
)
target
=
[
"u"
,
"v"
,
"e"
]
...
...
tests/compute/test_pickle.py
View file @
61139302
...
...
@@ -14,7 +14,6 @@ from test_utils import parametrize_idtype, get_cases
from
utils
import
assert_is_identical
,
assert_is_identical_hetero
def
_assert_is_identical_nodeflow
(
nf1
,
nf2
):
assert
nf1
.
is_readonly
==
nf2
.
is_readonly
assert
nf1
.
number_of_nodes
()
==
nf2
.
number_of_nodes
()
src
,
dst
=
nf1
.
all_edges
()
src2
,
dst2
=
nf2
.
all_edges
()
...
...
tests/compute/test_propagate.py
View file @
61139302
...
...
@@ -92,10 +92,10 @@ def test_prop_nodes_topo(idtype):
# tree
tree
=
dgl
.
DGLGraph
()
tree
.
add_nodes
(
5
)
tree
.
add_edge
(
1
,
0
)
tree
.
add_edge
(
2
,
0
)
tree
.
add_edge
(
3
,
2
)
tree
.
add_edge
(
4
,
2
)
tree
.
add_edge
s
(
1
,
0
)
tree
.
add_edge
s
(
2
,
0
)
tree
.
add_edge
s
(
3
,
2
)
tree
.
add_edge
s
(
4
,
2
)
tree
=
dgl
.
graph
(
tree
.
edges
())
# init node feature data
tree
.
ndata
[
"x"
]
=
F
.
zeros
((
5
,
2
))
...
...
tests/compute/test_removal.py
View file @
61139302
...
...
@@ -10,7 +10,7 @@ def test_node_removal(idtype):
g
=
dgl
.
DGLGraph
()
g
=
g
.
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
10
)
g
.
add_edge
(
0
,
0
)
g
.
add_edge
s
(
0
,
0
)
assert
g
.
number_of_nodes
()
==
10
g
.
ndata
[
"id"
]
=
F
.
arange
(
0
,
10
)
...
...
@@ -42,8 +42,8 @@ def test_multigraph_node_removal(idtype):
g
=
g
.
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
for
i
in
range
(
5
):
g
.
add_edge
(
i
,
i
)
g
.
add_edge
(
i
,
i
)
g
.
add_edge
s
(
i
,
i
)
g
.
add_edge
s
(
i
,
i
)
assert
g
.
number_of_nodes
()
==
5
assert
g
.
number_of_edges
()
==
10
...
...
@@ -54,8 +54,8 @@ def test_multigraph_node_removal(idtype):
# add nodes
g
.
add_nodes
(
1
)
g
.
add_edge
(
1
,
1
)
g
.
add_edge
(
1
,
1
)
g
.
add_edge
s
(
1
,
1
)
g
.
add_edge
s
(
1
,
1
)
assert
g
.
number_of_nodes
()
==
4
assert
g
.
number_of_edges
()
==
8
...
...
@@ -71,8 +71,8 @@ def test_multigraph_edge_removal(idtype):
g
=
g
.
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
5
)
for
i
in
range
(
5
):
g
.
add_edge
(
i
,
i
)
g
.
add_edge
(
i
,
i
)
g
.
add_edge
s
(
i
,
i
)
g
.
add_edge
s
(
i
,
i
)
assert
g
.
number_of_nodes
()
==
5
assert
g
.
number_of_edges
()
==
10
...
...
@@ -82,8 +82,8 @@ def test_multigraph_edge_removal(idtype):
assert
g
.
number_of_edges
()
==
8
# add edges
g
.
add_edge
(
1
,
1
)
g
.
add_edge
(
1
,
1
)
g
.
add_edge
s
(
1
,
1
)
g
.
add_edge
s
(
1
,
1
)
assert
g
.
number_of_nodes
()
==
5
assert
g
.
number_of_edges
()
==
10
...
...
@@ -100,7 +100,7 @@ def test_edge_removal(idtype):
g
.
add_nodes
(
5
)
for
i
in
range
(
5
):
for
j
in
range
(
5
):
g
.
add_edge
(
i
,
j
)
g
.
add_edge
s
(
i
,
j
)
g
.
edata
[
"id"
]
=
F
.
arange
(
0
,
25
)
# remove edges
...
...
@@ -114,7 +114,7 @@ def test_edge_removal(idtype):
assert
dgl
.
EID
not
in
g
.
edata
# add edges
g
.
add_edge
(
3
,
3
)
g
.
add_edge
s
(
3
,
3
)
assert
g
.
number_of_nodes
()
==
5
assert
g
.
number_of_edges
()
==
19
assert
F
.
array_equal
(
...
...
@@ -138,7 +138,7 @@ def test_node_and_edge_removal(idtype):
g
.
add_nodes
(
10
)
for
i
in
range
(
10
):
for
j
in
range
(
10
):
g
.
add_edge
(
i
,
j
)
g
.
add_edge
s
(
i
,
j
)
g
.
edata
[
"id"
]
=
F
.
arange
(
0
,
100
)
assert
g
.
number_of_nodes
()
==
10
assert
g
.
number_of_edges
()
==
100
...
...
@@ -161,7 +161,7 @@ def test_node_and_edge_removal(idtype):
# add edges
for
i
in
range
(
8
,
10
):
for
j
in
range
(
8
,
10
):
g
.
add_edge
(
i
,
j
)
g
.
add_edge
s
(
i
,
j
)
assert
g
.
number_of_nodes
()
==
10
assert
g
.
number_of_edges
()
==
58
...
...
tests/compute/test_sampler.py
View file @
61139302
...
...
@@ -375,7 +375,7 @@ def check_negative_sampler(mode, exclude_positive, neg_size):
for
i
in
range
(
len
(
neg_eid
)):
u
,
v
=
F
.
asnumpy
(
neg_src
[
i
]),
F
.
asnumpy
(
neg_dst
[
i
])
if
g
.
has_edge_between
(
u
,
v
):
eid
=
g
.
edge_id
(
u
,
v
)
eid
=
g
.
edge_id
s
(
u
,
v
)
etype
=
g
.
edata
[
'etype'
][
eid
]
exist
=
neg_edges
.
edata
[
'etype'
][
i
]
==
etype
assert
F
.
asnumpy
(
exists
[
i
])
==
F
.
asnumpy
(
exist
)
...
...
@@ -461,7 +461,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
for
i
in
range
(
len
(
neg_eid
)):
u
,
v
=
F
.
asnumpy
(
neg_src
[
i
]),
F
.
asnumpy
(
neg_dst
[
i
])
if
g
.
has_edge_between
(
u
,
v
):
eid
=
g
.
edge_id
(
u
,
v
)
eid
=
g
.
edge_id
s
(
u
,
v
)
etype
=
g
.
edata
[
'etype'
][
eid
]
exist
=
neg_edges
.
edata
[
'etype'
][
i
]
==
etype
assert
F
.
asnumpy
(
exists
[
i
])
==
F
.
asnumpy
(
exist
)
...
...
@@ -488,7 +488,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
for
i
in
range
(
len
(
neg_eid
)):
u
,
v
=
F
.
asnumpy
(
neg_src
[
i
]),
F
.
asnumpy
(
neg_dst
[
i
])
if
g
.
has_edge_between
(
u
,
v
):
eid
=
g
.
edge_id
(
u
,
v
)
eid
=
g
.
edge_id
s
(
u
,
v
)
etype
=
g
.
edata
[
'etype'
][
eid
]
exist
=
neg_edges
.
edata
[
'etype'
][
i
]
==
etype
assert
F
.
asnumpy
(
exists
[
i
])
==
F
.
asnumpy
(
exist
)
...
...
tests/compute/test_sampling.py
View file @
61139302
...
...
@@ -14,7 +14,7 @@ def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
for
i
in
range
(
traces
.
shape
[
0
]):
for
j
in
range
(
traces
.
shape
[
1
]
-
1
):
assert
g
.
has_edge_between
(
assert
g
.
has_edge
s
_between
(
traces
[
i
,
j
],
traces
[
i
,
j
+
1
],
etype
=
metapath
[
j
])
if
prob
is
not
None
and
prob
in
g
.
edges
[
metapath
[
j
]].
data
:
p
=
F
.
asnumpy
(
g
.
edges
[
metapath
[
j
]].
data
[
'p'
])
...
...
tests/compute/test_shared_mem.py
View file @
61139302
...
...
@@ -23,7 +23,6 @@ def create_test_graph(idtype):
return
g
def
_assert_is_identical_hetero
(
g
,
g2
):
assert
g
.
is_readonly
==
g2
.
is_readonly
assert
g
.
ntypes
==
g2
.
ntypes
assert
g
.
canonical_etypes
==
g2
.
canonical_etypes
...
...
tests/compute/test_specialization.py
View file @
61139302
...
...
@@ -13,10 +13,10 @@ def generate_graph(idtype):
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
g
.
add_edge
s
(
9
,
0
)
g
.
ndata
.
update
({
'f1'
:
F
.
randn
((
10
,)),
'f2'
:
F
.
randn
((
10
,
D
))})
weights
=
F
.
randn
((
17
,))
g
.
edata
.
update
({
'e1'
:
weights
,
'e2'
:
F
.
unsqueeze
(
weights
,
1
)})
...
...
@@ -42,7 +42,7 @@ def test_v2v_update_all(idtype):
g
=
generate_graph
(
idtype
)
# update all
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
copy_
src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
g
.
update_all
(
fn
.
copy_
u
(
u
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
.
update
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
...
...
@@ -50,7 +50,7 @@ def test_v2v_update_all(idtype):
assert
F
.
allclose
(
v2
,
v3
)
# update all with edge weights
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
src
_mul_e
dge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
g
.
update_all
(
fn
.
u
_mul_e
(
fld
,
'e1'
,
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
.
update
({
fld
:
v1
})
...
...
@@ -84,7 +84,7 @@ def test_v2v_snr(idtype):
g
=
generate_graph
(
idtype
)
# send and recv
v1
=
g
.
ndata
[
fld
]
g
.
send_and_recv
((
u
,
v
),
fn
.
copy_
src
(
src
=
fld
,
out
=
'm'
),
g
.
send_and_recv
((
u
,
v
),
fn
.
copy_
u
(
u
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
.
update
({
fld
:
v1
})
...
...
@@ -93,7 +93,7 @@ def test_v2v_snr(idtype):
assert
F
.
allclose
(
v2
,
v3
)
# send and recv with edge weights
v1
=
g
.
ndata
[
fld
]
g
.
send_and_recv
((
u
,
v
),
fn
.
src
_mul_e
dge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
g
.
send_and_recv
((
u
,
v
),
fn
.
u
_mul_e
(
fld
,
'e1'
,
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
.
update
({
fld
:
v1
})
...
...
@@ -127,7 +127,7 @@ def test_v2v_pull(idtype):
g
=
generate_graph
(
idtype
)
# send and recv
v1
=
g
.
ndata
[
fld
]
g
.
pull
(
nodes
,
fn
.
copy_
src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
g
.
pull
(
nodes
,
fn
.
copy_
u
(
u
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
[
fld
]
=
v1
g
.
pull
(
nodes
,
message_func
,
reduce_func
,
apply_func
)
...
...
@@ -135,7 +135,7 @@ def test_v2v_pull(idtype):
assert
F
.
allclose
(
v2
,
v3
)
# send and recv with edge weights
v1
=
g
.
ndata
[
fld
]
g
.
pull
(
nodes
,
fn
.
src
_mul_e
dge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
g
.
pull
(
nodes
,
fn
.
u
_mul_e
(
fld
,
'e1'
,
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
ndata
[
fld
]
=
v1
...
...
@@ -154,8 +154,8 @@ def test_update_all_multi_fallback(idtype):
g
=
g
.
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
10
)
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
i
,
9
)
g
.
ndata
[
'h'
]
=
F
.
randn
((
10
,
D
))
g
.
edata
[
'w1'
]
=
F
.
randn
((
16
,))
g
.
edata
[
'w2'
]
=
F
.
randn
((
16
,
D
))
...
...
@@ -183,12 +183,12 @@ def test_update_all_multi_fallback(idtype):
g
.
update_all
(
_mfunc_hxw1
,
_rfunc_m1max
,
_afunc
)
o3
=
g
.
ndata
.
pop
(
'o3'
)
# v2v spmv
g
.
update_all
(
fn
.
src
_mul_e
dge
(
src
=
'h'
,
edge
=
'w1'
,
out
=
'm1'
),
g
.
update_all
(
fn
.
u
_mul_e
(
'h'
,
'w1'
,
'm1'
),
fn
.
sum
(
msg
=
'm1'
,
out
=
'o1'
),
_afunc
)
assert
F
.
allclose
(
o1
,
g
.
ndata
.
pop
(
'o1'
))
# v2v fallback to e2v
g
.
update_all
(
fn
.
src
_mul_e
dge
(
src
=
'h'
,
edge
=
'w2'
,
out
=
'm2'
),
g
.
update_all
(
fn
.
u
_mul_e
(
'h'
,
'w2'
,
'm2'
),
fn
.
sum
(
msg
=
'm2'
,
out
=
'o2'
),
_afunc
)
assert
F
.
allclose
(
o2
,
g
.
ndata
.
pop
(
'o2'
))
...
...
@@ -200,8 +200,8 @@ def test_pull_multi_fallback(idtype):
g
=
g
.
astype
(
idtype
).
to
(
F
.
ctx
())
g
.
add_nodes
(
10
)
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
i
,
9
)
g
.
ndata
[
'h'
]
=
F
.
randn
((
10
,
D
))
g
.
edata
[
'w1'
]
=
F
.
randn
((
16
,))
g
.
edata
[
'w2'
]
=
F
.
randn
((
16
,
D
))
...
...
@@ -231,12 +231,12 @@ def test_pull_multi_fallback(idtype):
g
.
pull
(
nodes
,
_mfunc_hxw1
,
_rfunc_m1max
,
_afunc
)
o3
=
g
.
ndata
.
pop
(
'o3'
)
# v2v spmv
g
.
pull
(
nodes
,
fn
.
src
_mul_e
dge
(
src
=
'h'
,
edge
=
'w1'
,
out
=
'm1'
),
g
.
pull
(
nodes
,
fn
.
u
_mul_e
(
'h'
,
'w1'
,
'm1'
),
fn
.
sum
(
msg
=
'm1'
,
out
=
'o1'
),
_afunc
)
assert
F
.
allclose
(
o1
,
g
.
ndata
.
pop
(
'o1'
))
# v2v fallback to e2v
g
.
pull
(
nodes
,
fn
.
src
_mul_e
dge
(
src
=
'h'
,
edge
=
'w2'
,
out
=
'm2'
),
g
.
pull
(
nodes
,
fn
.
u
_mul_e
(
'h'
,
'w2'
,
'm2'
),
fn
.
sum
(
msg
=
'm2'
,
out
=
'o2'
),
_afunc
)
assert
F
.
allclose
(
o2
,
g
.
ndata
.
pop
(
'o2'
))
...
...
@@ -268,7 +268,7 @@ def test_spmv_3d_feat(idtype):
g
.
ndata
[
'h'
]
=
h
g
.
edata
[
'h'
]
=
e
g
.
update_all
(
message_func
=
fn
.
src
_mul_e
dge
(
'h'
,
'h'
,
'sum'
),
reduce_func
=
fn
.
sum
(
'sum'
,
'h'
))
# 1
g
.
update_all
(
message_func
=
fn
.
u
_mul_e
(
'h'
,
'h'
,
'sum'
),
reduce_func
=
fn
.
sum
(
'sum'
,
'h'
))
# 1
ans
=
g
.
ndata
[
'h'
]
g
.
ndata
[
'h'
]
=
h
...
...
@@ -290,7 +290,7 @@ def test_spmv_3d_feat(idtype):
g
.
ndata
[
'h'
]
=
h
g
.
edata
[
'h'
]
=
e
g
.
update_all
(
message_func
=
fn
.
src
_mul_e
dge
(
'h'
,
'h'
,
'sum'
),
reduce_func
=
fn
.
sum
(
'sum'
,
'h'
))
# 1
g
.
update_all
(
message_func
=
fn
.
u
_mul_e
(
'h'
,
'h'
,
'sum'
),
reduce_func
=
fn
.
sum
(
'sum'
,
'h'
))
# 1
ans
=
g
.
ndata
[
'h'
]
g
.
ndata
[
'h'
]
=
h
...
...
tests/compute/test_subgraph.py
View file @
61139302
...
...
@@ -17,10 +17,10 @@ def generate_graph(grad=False, add_data=True):
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
g
.
add_edge
s
(
0
,
i
)
g
.
add_edge
s
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
g
.
add_edge
s
(
9
,
0
)
if
add_data
:
ncol
=
F
.
randn
((
10
,
D
))
ecol
=
F
.
randn
((
17
,
D
))
...
...
tests/compute/test_transform.py
View file @
61139302
...
...
@@ -99,10 +99,10 @@ def test_no_backtracking():
L
=
G
.
line_graph
(
backtracking
=
False
)
assert
L
.
number_of_nodes
()
==
2
*
N
for
i
in
range
(
1
,
N
):
e1
=
G
.
edge_id
(
0
,
i
)
e2
=
G
.
edge_id
(
i
,
0
)
assert
not
L
.
has_edge_between
(
e1
,
e2
)
assert
not
L
.
has_edge_between
(
e2
,
e1
)
e1
=
G
.
edge_id
s
(
0
,
i
)
e2
=
G
.
edge_id
s
(
i
,
0
)
assert
not
L
.
has_edge
s
_between
(
e1
,
e2
)
assert
not
L
.
has_edge
s
_between
(
e2
,
e1
)
# reverse graph related
@
parametrize_idtype
...
...
@@ -122,9 +122,9 @@ def test_reverse(idtype):
assert
g
.
number_of_edges
()
==
rg
.
number_of_edges
()
assert
F
.
allclose
(
F
.
astype
(
rg
.
has_edges_between
(
[
1
,
2
,
1
],
[
0
,
1
,
2
]),
F
.
float32
),
F
.
ones
((
3
,)))
assert
g
.
edge_id
(
0
,
1
)
==
rg
.
edge_id
(
1
,
0
)
assert
g
.
edge_id
(
1
,
2
)
==
rg
.
edge_id
(
2
,
1
)
assert
g
.
edge_id
(
2
,
1
)
==
rg
.
edge_id
(
1
,
2
)
assert
g
.
edge_id
s
(
0
,
1
)
==
rg
.
edge_id
s
(
1
,
0
)
assert
g
.
edge_id
s
(
1
,
2
)
==
rg
.
edge_id
s
(
2
,
1
)
assert
g
.
edge_id
s
(
2
,
1
)
==
rg
.
edge_id
s
(
1
,
2
)
# test dgl.reverse
# test homogeneous graph
...
...
tests/compute/utils.py
View file @
61139302
...
...
@@ -11,7 +11,6 @@ def check_fail(fn, *args, **kwargs):
return
True
def
assert_is_identical
(
g
,
g2
):
assert
g
.
is_readonly
==
g2
.
is_readonly
assert
g
.
number_of_nodes
()
==
g2
.
number_of_nodes
()
src
,
dst
=
g
.
all_edges
(
order
=
'eid'
)
src2
,
dst2
=
g2
.
all_edges
(
order
=
'eid'
)
...
...
@@ -26,7 +25,6 @@ def assert_is_identical(g, g2):
assert
F
.
allclose
(
g
.
edata
[
k
],
g2
.
edata
[
k
])
def
assert_is_identical_hetero
(
g
,
g2
,
ignore_internal_data
=
False
):
assert
g
.
is_readonly
==
g2
.
is_readonly
assert
g
.
ntypes
==
g2
.
ntypes
assert
g
.
canonical_etypes
==
g2
.
canonical_etypes
...
...
tests/distributed/test_distributed_sampling.py
View file @
61139302
...
...
@@ -96,7 +96,6 @@ def check_rpc_sampling(tmpdir, num_server):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
.
readonly
()
print
(
g
.
idtype
)
num_parts
=
num_server
num_hops
=
1
...
...
@@ -128,7 +127,6 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
.
readonly
()
num_parts
=
num_server
orig_nid
,
orig_eid
=
partition_graph
(
g
,
'test_find_edges'
,
num_parts
,
tmpdir
,
...
...
@@ -225,7 +223,6 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
.
readonly
()
num_parts
=
num_server
orig_nid
,
_
=
partition_graph
(
g
,
'test_get_degrees'
,
num_parts
,
tmpdir
,
...
...
@@ -278,7 +275,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
.
readonly
()
num_parts
=
num_server
num_hops
=
1
...
...
@@ -906,7 +902,6 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
generate_ip_config
(
"rpc_ip_config.txt"
,
num_server
,
num_server
)
g
=
CitationGraphDataset
(
"cora"
)[
0
]
g
.
readonly
()
num_parts
=
num_server
orig_nid
,
orig_eid
=
partition_graph
(
g
,
'test_in_subgraph'
,
num_parts
,
tmpdir
,
...
...
tests/distributed/test_partition.py
View file @
61139302
...
...
@@ -341,8 +341,8 @@ def check_partition(
g
.
edata
[
"feats"
]
=
F
.
tensor
(
np
.
random
.
randn
(
g
.
number_of_edges
(),
10
),
F
.
float32
)
g
.
update_all
(
fn
.
copy_
src
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"h"
))
g
.
update_all
(
fn
.
copy_e
dge
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"eh"
))
g
.
update_all
(
fn
.
copy_
u
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"h"
))
g
.
update_all
(
fn
.
copy_e
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"eh"
))
num_hops
=
2
orig_nids
,
orig_eids
=
partition_graph
(
...
...
@@ -464,8 +464,8 @@ def check_partition(
g
.
edata
[
"feats"
],
part_g
.
edata
[
dgl
.
NID
]
)
part_g
.
update_all
(
fn
.
copy_
src
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"h"
))
part_g
.
update_all
(
fn
.
copy_e
dge
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"eh"
))
part_g
.
update_all
(
fn
.
copy_
u
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"h"
))
part_g
.
update_all
(
fn
.
copy_e
(
"feats"
,
"msg"
),
fn
.
sum
(
"msg"
,
"eh"
))
assert
F
.
allclose
(
F
.
gather_row
(
g
.
ndata
[
"h"
],
local_nodes
),
F
.
gather_row
(
part_g
.
ndata
[
"h"
],
llocal_nodes
),
...
...
tests/distributed/test_shared_mem_store.py
View file @
61139302
...
...
@@ -133,7 +133,7 @@ def check_compute_func(worker_id, graph_name, return_dict):
g
.
_sync_barrier
(
60
)
in_feats
=
g
.
nodes
[
0
].
data
[
'feat'
].
shape
[
1
]
# Test update all.
g
.
update_all
(
fn
.
copy_
src
(
src
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'preprocess'
))
g
.
update_all
(
fn
.
copy_
u
(
u
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'preprocess'
))
adj
=
g
.
adjacency_matrix
(
transpose
=
True
)
tmp
=
F
.
spmm
(
adj
,
g
.
nodes
[:].
data
[
'feat'
])
assert_almost_equal
(
F
.
asnumpy
(
g
.
nodes
[:].
data
[
'preprocess'
]),
F
.
asnumpy
(
tmp
))
...
...
@@ -154,12 +154,12 @@ def check_compute_func(worker_id, graph_name, return_dict):
g
.
init_ndata
(
'tmp'
,
(
g
.
number_of_nodes
(),
10
),
'float32'
)
data
=
g
.
nodes
[:].
data
[
'tmp'
]
# Test pull
g
.
pull
(
1
,
fn
.
copy_
src
(
src
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'tmp'
))
g
.
pull
(
1
,
fn
.
copy_
u
(
u
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'tmp'
))
assert_almost_equal
(
F
.
asnumpy
(
data
[
1
]),
np
.
squeeze
(
F
.
asnumpy
(
g
.
nodes
[
1
].
data
[
'preprocess'
])))
# Test send_and_recv
in_edges
=
g
.
in_edges
(
v
=
2
)
g
.
send_and_recv
(
in_edges
,
fn
.
copy_
src
(
src
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'tmp'
))
g
.
send_and_recv
(
in_edges
,
fn
.
copy_
u
(
u
=
'feat'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'tmp'
))
assert_almost_equal
(
F
.
asnumpy
(
data
[
2
]),
np
.
squeeze
(
F
.
asnumpy
(
g
.
nodes
[
2
].
data
[
'preprocess'
])))
g
.
destroy
()
...
...
tutorials/models/1_gnn/6_line_graph.py
View file @
61139302
...
...
@@ -372,12 +372,12 @@ def aggregate_radius(radius, g, z):
z_list
=
[]
g
.
ndata
[
'z'
]
=
z
# pulling message from 1-hop neighbourhood
g
.
update_all
(
fn
.
copy_
src
(
src
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
g
.
update_all
(
fn
.
copy_
u
(
u
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
z_list
.
append
(
g
.
ndata
[
'z'
])
for
i
in
range
(
radius
-
1
):
for
j
in
range
(
2
**
i
):
#pulling message from 2^j neighborhood
g
.
update_all
(
fn
.
copy_
src
(
src
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
g
.
update_all
(
fn
.
copy_
u
(
u
=
'z'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'z'
))
z_list
.
append
(
g
.
ndata
[
'z'
])
return
z_list
...
...
tutorials/models/2_small_graph/3_tree-lstm.py
View file @
61139302
...
...
@@ -243,7 +243,7 @@ import torch as th
trv_graph
.
ndata
[
'a'
]
=
th
.
ones
(
graph
.
number_of_nodes
(),
1
)
traversal_order
=
dgl
.
topological_nodes_generator
(
trv_graph
)
trv_graph
.
prop_nodes
(
traversal_order
,
message_func
=
fn
.
copy_
src
(
'a'
,
'a'
),
message_func
=
fn
.
copy_
u
(
'a'
,
'a'
),
reduce_func
=
fn
.
sum
(
'a'
,
'a'
))
# the following is a syntax sugar that does the same
...
...
tutorials/models/3_generative_model/5_dgmg.py
View file @
61139302
...
...
@@ -625,7 +625,7 @@ class ChooseDestAndUpdate(nn.Module):
if
not
self
.
training
:
dest
=
Categorical
(
dests_probs
).
sample
().
item
()
if
not
g
.
has_edge_between
(
src
,
dest
):
if
not
g
.
has_edge
s
_between
(
src
,
dest
):
# For undirected graphs, add edges for both directions
# so that you can perform graph propagation.
src_list
=
[
src
,
dest
]
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment