Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
13204383
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "177dc133d685d69a89a775eb9d1ca720094e14c7"
Unverified
Commit
13204383
authored
Aug 20, 2023
by
Andrei Ivanov
Committed by
GitHub
Aug 20, 2023
Browse files
Improving basic tests. (#6145)
parent
44f4b0e2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
8 deletions
+17
-8
tests/python/common/function/test_basics.py
tests/python/common/function/test_basics.py
+17
-8
No files found.
tests/python/common/function/test_basics.py
View file @
13204383
import
unittest
import
warnings
from
collections
import
defaultdict
as
ddict
import
backend
as
F
...
...
@@ -6,8 +6,6 @@ import backend as F
import
dgl
import
networkx
as
nx
import
numpy
as
np
import
scipy.sparse
as
ssp
from
dgl
import
DGLGraph
from
utils
import
parametrize_idtype
D
=
5
...
...
@@ -33,7 +31,7 @@ def apply_node_func(nodes):
def
generate_graph_old
(
grad
=
False
):
g
=
DGLG
raph
()
g
=
dgl
.
g
raph
(
[]
)
g
.
add_nodes
(
10
)
# 10 nodes
# create a graph where 0 is the source and 9 is the sink
# 17 edges
...
...
@@ -419,7 +417,14 @@ def test_update_all_0deg(idtype):
# test#2: graph with no edge
g
=
dgl
.
graph
(([],
[]),
num_nodes
=
5
,
idtype
=
idtype
,
device
=
F
.
ctx
())
g
.
ndata
[
"h"
]
=
old_repr
g
.
update_all
(
_message
,
_reduce
,
lambda
nodes
:
{
"h"
:
nodes
.
data
[
"h"
]
*
2
})
# Intercepting the warning: The input graph for the user-defined edge
# function does not contain valid edges.
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
g
.
update_all
(
_message
,
_reduce
,
lambda
nodes
:
{
"h"
:
nodes
.
data
[
"h"
]
*
2
}
)
new_repr
=
g
.
ndata
[
"h"
]
# should fallback to apply
assert
F
.
allclose
(
new_repr
,
2
*
old_repr
)
...
...
@@ -455,7 +460,12 @@ def test_pull_0deg(idtype):
# test#2: pull only 0deg node
old
=
F
.
randn
((
2
,
5
))
g
.
ndata
[
"h"
]
=
old
g
.
pull
(
0
,
_message
,
_reduce
,
lambda
nodes
:
{
"h"
:
nodes
.
data
[
"h"
]
*
2
})
# Intercepting the warning: The input graph for the user-defined edge
# function does not contain valid edges
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
g
.
pull
(
0
,
_message
,
_reduce
,
lambda
nodes
:
{
"h"
:
nodes
.
data
[
"h"
]
*
2
})
new
=
g
.
ndata
[
"h"
]
# 0deg check: fallback to apply
assert
F
.
allclose
(
new
[
0
],
2
*
old
[
0
])
...
...
@@ -467,8 +477,7 @@ def test_dynamic_addition():
N
=
3
D
=
1
g
=
DGLGraph
()
g
=
g
.
to
(
F
.
ctx
())
g
=
dgl
.
graph
([]).
to
(
F
.
ctx
())
# Test node addition
g
.
add_nodes
(
N
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment