Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8b839a23
Unverified
Commit
8b839a23
authored
Aug 20, 2023
by
Andrei Ivanov
Committed by
GitHub
Aug 20, 2023
Browse files
Improving data tests. (#6144)
parent
13204383
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
44 deletions
+50
-44
python/dgl/data/citation_graph.py
python/dgl/data/citation_graph.py
+7
-4
tests/python/common/data/test_data.py
tests/python/common/data/test_data.py
+43
-40
No files found.
python/dgl/data/citation_graph.py
View file @
8b839a23
...
...
@@ -7,6 +7,7 @@ from __future__ import absolute_import
import
os
,
sys
import
pickle
as
pkl
import
warnings
import
networkx
as
nx
...
...
@@ -34,10 +35,12 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
def
_pickle_load
(
pkl_file
):
if
sys
.
version_info
>
(
3
,
0
):
return
pkl
.
load
(
pkl_file
,
encoding
=
"latin1"
)
else
:
return
pkl
.
load
(
pkl_file
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
DeprecationWarning
)
if
sys
.
version_info
>
(
3
,
0
):
return
pkl
.
load
(
pkl_file
,
encoding
=
"latin1"
)
else
:
return
pkl
.
load
(
pkl_file
)
class
CitationGraphDataset
(
DGLBuiltinDataset
):
...
...
tests/python/common/data/test_data.py
View file @
8b839a23
...
...
@@ -4,6 +4,7 @@ import os
import
tarfile
import
tempfile
import
unittest
import
warnings
import
backend
as
F
...
...
@@ -736,54 +737,56 @@ def _test_construct_graphs_multiple():
assert
expect_except
def
_
test_DefaultDataParser
(
):
def
_
get_data_table
(
data_frame
):
from
dgl.data.csv_dataset_base
import
DefaultDataParser
# common csv
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
num_nodes
=
5
num_labels
=
3
num_dims
=
2
node_id
=
np
.
arange
(
num_nodes
)
label
=
np
.
random
.
randint
(
num_labels
,
size
=
num_nodes
)
feat
=
np
.
random
.
rand
(
num_nodes
,
num_dims
)
df
=
pd
.
DataFrame
(
{
"node_id"
:
node_id
,
"label"
:
label
,
"feat"
:
[
line
.
tolist
()
for
line
in
feat
],
}
)
df
.
to_csv
(
csv_path
,
index
=
False
)
data_frame
.
to_csv
(
csv_path
,
index
=
False
)
dp
=
DefaultDataParser
()
df
=
pd
.
read_csv
(
csv_path
)
dt
=
dp
(
df
)
assert
np
.
array_equal
(
node_id
,
dt
[
"node_id"
])
assert
np
.
array_equal
(
label
,
dt
[
"label"
])
assert
np
.
array_equal
(
feat
,
dt
[
"feat"
])
# Intercepting the warning: "Unamed column is found. Ignored...".
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
return
dp
(
df
)
def
_test_DefaultDataParser
():
# common csv
num_nodes
=
5
num_labels
=
3
num_dims
=
2
node_id
=
np
.
arange
(
num_nodes
)
label
=
np
.
random
.
randint
(
num_labels
,
size
=
num_nodes
)
feat
=
np
.
random
.
rand
(
num_nodes
,
num_dims
)
df
=
pd
.
DataFrame
(
{
"node_id"
:
node_id
,
"label"
:
label
,
"feat"
:
[
line
.
tolist
()
for
line
in
feat
],
}
)
dt
=
_get_data_table
(
df
)
assert
np
.
array_equal
(
node_id
,
dt
[
"node_id"
])
assert
np
.
array_equal
(
label
,
dt
[
"label"
])
assert
np
.
array_equal
(
feat
,
dt
[
"feat"
])
# string consists of non-numeric values
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
df
=
pd
.
DataFrame
({
"label"
:
[
"a"
,
"b"
,
"c"
]})
df
.
to_csv
(
csv_path
,
index
=
False
)
dp
=
DefaultDataParser
()
df
=
pd
.
read_csv
(
csv_path
)
expect_except
=
False
try
:
dt
=
dp
(
df
)
except
:
expect_except
=
True
assert
expect_except
df
=
pd
.
DataFrame
({
"label"
:
[
"a"
,
"b"
,
"c"
]})
expect_except
=
False
try
:
_get_data_table
(
df
)
except
:
expect_except
=
True
assert
expect_except
# csv has index column which is ignored as it's unnamed
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
df
=
pd
.
DataFrame
({
"label"
:
[
1
,
2
,
3
]})
df
.
to_csv
(
csv_path
)
dp
=
DefaultDataParser
()
df
=
pd
.
read_csv
(
csv_path
)
dt
=
dp
(
df
)
assert
len
(
dt
)
==
1
df
=
pd
.
DataFrame
({
"label"
:
[
1
,
2
,
3
]})
dt
=
_get_data_table
(
df
)
assert
len
(
dt
)
==
1
def
_test_load_yaml_with_sanity_check
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment