Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
365d3617
Unverified
Commit
365d3617
authored
Nov 29, 2019
by
Quan (Andy) Gan
Committed by
GitHub
Nov 29, 2019
Browse files
[Bug] Fix #1036 (#1037)
* fix * unit test
parent
287f387b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
4 deletions
+29
-4
python/dgl/graph.py
python/dgl/graph.py
+2
-2
python/dgl/heterograph.py
python/dgl/heterograph.py
+2
-2
tests/compute/test_basics.py
tests/compute/test_basics.py
+25
-0
No files found.
python/dgl/graph.py
View file @
365d3617
...
...
@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph):
raise
DGLError
(
"Group_by should be either src or dst"
)
if
is_all
(
edges
):
u
,
v
,
_
=
self
.
_graph
.
edges
()
u
,
v
,
_
=
self
.
_graph
.
edges
(
'eid'
)
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
...
...
@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph):
if
is_all
(
edges
):
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
u
,
v
,
_
=
self
.
_graph
.
edges
()
u
,
v
,
_
=
self
.
_graph
.
edges
(
'eid'
)
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
=
utils
.
toindex
(
u
)
...
...
python/dgl/heterograph.py
View file @
365d3617
...
...
@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object):
etid
=
self
.
get_etype_id
(
etype
)
stid
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
if
is_all
(
edges
):
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
)
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
,
'eid'
)
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
(
etype
)))
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
...
...
@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object):
if
is_all
(
edges
):
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
_graph
.
number_of_edges
(
etid
)))
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
)
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
,
'eid'
)
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
=
utils
.
toindex
(
u
)
...
...
tests/compute/test_basics.py
View file @
365d3617
import
backend
as
F
import
dgl
import
numpy
as
np
import
scipy.sparse
as
ssp
import
networkx
as
nx
from
dgl
import
DGLGraph
from
collections
import
defaultdict
as
ddict
...
...
@@ -654,6 +656,28 @@ def test_group_apply_edges():
# test group by destination nodes
_test
(
'dst'
)
# GitHub issue #1036
def
test_group_apply_edges2
():
m
=
ssp
.
random
(
10
,
10
,
0.2
)
g
=
DGLGraph
(
m
,
readonly
=
True
)
g
.
ndata
[
'deg'
]
=
g
.
in_degrees
()
g
.
ndata
[
'id'
]
=
F
.
arange
(
0
,
g
.
number_of_nodes
())
g
.
edata
[
'id'
]
=
F
.
arange
(
0
,
g
.
number_of_edges
())
def
apply
(
edges
):
w
=
edges
.
data
[
'id'
]
n_nodes
,
deg
=
w
.
shape
dst
=
edges
.
dst
[
'id'
][:,
0
]
eid1
=
F
.
asnumpy
(
g
.
in_edges
(
dst
,
'eid'
)).
reshape
(
n_nodes
,
deg
).
sort
(
1
)
eid2
=
F
.
asnumpy
(
edges
.
data
[
'id'
]).
sort
(
1
)
assert
np
.
array_equal
(
eid1
,
eid2
)
return
{
'id2'
:
w
}
g
.
group_apply_edges
(
'dst'
,
apply
,
inplace
=
True
)
def
test_local_var
():
g
=
DGLGraph
(
nx
.
path_graph
(
5
))
g
.
ndata
[
'h'
]
=
F
.
zeros
((
g
.
number_of_nodes
(),
3
))
...
...
@@ -803,5 +827,6 @@ if __name__ == '__main__':
test_dynamic_addition
()
test_repr
()
test_group_apply_edges
()
test_group_apply_edges2
()
test_local_var
()
test_local_scope
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment