Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
365d3617
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "65d83ad70c664359139870068aa8357a6553c0e6"
Unverified
Commit
365d3617
authored
Nov 29, 2019
by
Quan (Andy) Gan
Committed by
GitHub
Nov 29, 2019
Browse files
[Bug] Fix #1036 (#1037)
* fix * unit test
parent
287f387b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
4 deletions
+29
-4
python/dgl/graph.py
python/dgl/graph.py
+2
-2
python/dgl/heterograph.py
python/dgl/heterograph.py
+2
-2
tests/compute/test_basics.py
tests/compute/test_basics.py
+25
-0
No files found.
python/dgl/graph.py
View file @
365d3617
...
@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph):
...
@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph):
raise
DGLError
(
"Group_by should be either src or dst"
)
raise
DGLError
(
"Group_by should be either src or dst"
)
if
is_all
(
edges
):
if
is_all
(
edges
):
u
,
v
,
_
=
self
.
_graph
.
edges
()
u
,
v
,
_
=
self
.
_graph
.
edges
(
'eid'
)
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
elif
isinstance
(
edges
,
tuple
):
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
,
v
=
edges
...
@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph):
...
@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph):
if
is_all
(
edges
):
if
is_all
(
edges
):
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
()))
u
,
v
,
_
=
self
.
_graph
.
edges
()
u
,
v
,
_
=
self
.
_graph
.
edges
(
'eid'
)
elif
isinstance
(
edges
,
tuple
):
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
,
v
=
edges
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
...
...
python/dgl/heterograph.py
View file @
365d3617
...
@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object):
...
@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object):
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
stid
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
stid
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
if
is_all
(
edges
):
if
is_all
(
edges
):
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
)
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
,
'eid'
)
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
(
etype
)))
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_edges
(
etype
)))
elif
isinstance
(
edges
,
tuple
):
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
,
v
=
edges
...
@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object):
...
@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object):
if
is_all
(
edges
):
if
is_all
(
edges
):
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
_graph
.
number_of_edges
(
etid
)))
eid
=
utils
.
toindex
(
slice
(
0
,
self
.
_graph
.
number_of_edges
(
etid
)))
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
)
u
,
v
,
_
=
self
.
_graph
.
edges
(
etid
,
'eid'
)
elif
isinstance
(
edges
,
tuple
):
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
,
v
=
edges
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
...
...
tests/compute/test_basics.py
View file @
365d3617
import
backend
as
F
import
backend
as
F
import
dgl
import
dgl
import
numpy
as
np
import
scipy.sparse
as
ssp
import
networkx
as
nx
import
networkx
as
nx
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
from
collections
import
defaultdict
as
ddict
from
collections
import
defaultdict
as
ddict
...
@@ -654,6 +656,28 @@ def test_group_apply_edges():
...
@@ -654,6 +656,28 @@ def test_group_apply_edges():
# test group by destination nodes
# test group by destination nodes
_test
(
'dst'
)
_test
(
'dst'
)
# GitHub issue #1036
def
test_group_apply_edges2
():
m
=
ssp
.
random
(
10
,
10
,
0.2
)
g
=
DGLGraph
(
m
,
readonly
=
True
)
g
.
ndata
[
'deg'
]
=
g
.
in_degrees
()
g
.
ndata
[
'id'
]
=
F
.
arange
(
0
,
g
.
number_of_nodes
())
g
.
edata
[
'id'
]
=
F
.
arange
(
0
,
g
.
number_of_edges
())
def
apply
(
edges
):
w
=
edges
.
data
[
'id'
]
n_nodes
,
deg
=
w
.
shape
dst
=
edges
.
dst
[
'id'
][:,
0
]
eid1
=
F
.
asnumpy
(
g
.
in_edges
(
dst
,
'eid'
)).
reshape
(
n_nodes
,
deg
).
sort
(
1
)
eid2
=
F
.
asnumpy
(
edges
.
data
[
'id'
]).
sort
(
1
)
assert
np
.
array_equal
(
eid1
,
eid2
)
return
{
'id2'
:
w
}
g
.
group_apply_edges
(
'dst'
,
apply
,
inplace
=
True
)
def
test_local_var
():
def
test_local_var
():
g
=
DGLGraph
(
nx
.
path_graph
(
5
))
g
=
DGLGraph
(
nx
.
path_graph
(
5
))
g
.
ndata
[
'h'
]
=
F
.
zeros
((
g
.
number_of_nodes
(),
3
))
g
.
ndata
[
'h'
]
=
F
.
zeros
((
g
.
number_of_nodes
(),
3
))
...
@@ -803,5 +827,6 @@ if __name__ == '__main__':
...
@@ -803,5 +827,6 @@ if __name__ == '__main__':
test_dynamic_addition
()
test_dynamic_addition
()
test_repr
()
test_repr
()
test_group_apply_edges
()
test_group_apply_edges
()
test_group_apply_edges2
()
test_local_var
()
test_local_var
()
test_local_scope
()
test_local_scope
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment