Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8b839a23
Unverified
Commit
8b839a23
authored
Aug 20, 2023
by
Andrei Ivanov
Committed by
GitHub
Aug 20, 2023
Browse files
Improving data tests. (#6144)
parent
13204383
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
44 deletions
+50
-44
python/dgl/data/citation_graph.py
python/dgl/data/citation_graph.py
+7
-4
tests/python/common/data/test_data.py
tests/python/common/data/test_data.py
+43
-40
No files found.
python/dgl/data/citation_graph.py
View file @
8b839a23
...
@@ -7,6 +7,7 @@ from __future__ import absolute_import
...
@@ -7,6 +7,7 @@ from __future__ import absolute_import
import
os
,
sys
import
os
,
sys
import
pickle
as
pkl
import
pickle
as
pkl
import
warnings
import
networkx
as
nx
import
networkx
as
nx
...
@@ -34,10 +35,12 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
...
@@ -34,10 +35,12 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
def
_pickle_load
(
pkl_file
):
def
_pickle_load
(
pkl_file
):
if
sys
.
version_info
>
(
3
,
0
):
with
warnings
.
catch_warnings
():
return
pkl
.
load
(
pkl_file
,
encoding
=
"latin1"
)
warnings
.
simplefilter
(
"ignore"
,
category
=
DeprecationWarning
)
else
:
if
sys
.
version_info
>
(
3
,
0
):
return
pkl
.
load
(
pkl_file
)
return
pkl
.
load
(
pkl_file
,
encoding
=
"latin1"
)
else
:
return
pkl
.
load
(
pkl_file
)
class
CitationGraphDataset
(
DGLBuiltinDataset
):
class
CitationGraphDataset
(
DGLBuiltinDataset
):
...
...
tests/python/common/data/test_data.py
View file @
8b839a23
...
@@ -4,6 +4,7 @@ import os
...
@@ -4,6 +4,7 @@ import os
import
tarfile
import
tarfile
import
tempfile
import
tempfile
import
unittest
import
unittest
import
warnings
import
backend
as
F
import
backend
as
F
...
@@ -736,54 +737,56 @@ def _test_construct_graphs_multiple():
...
@@ -736,54 +737,56 @@ def _test_construct_graphs_multiple():
assert
expect_except
assert
expect_except
def
_
test_DefaultDataParser
(
):
def
_
get_data_table
(
data_frame
):
from
dgl.data.csv_dataset_base
import
DefaultDataParser
from
dgl.data.csv_dataset_base
import
DefaultDataParser
# common csv
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
num_nodes
=
5
num_labels
=
3
data_frame
.
to_csv
(
csv_path
,
index
=
False
)
num_dims
=
2
node_id
=
np
.
arange
(
num_nodes
)
label
=
np
.
random
.
randint
(
num_labels
,
size
=
num_nodes
)
feat
=
np
.
random
.
rand
(
num_nodes
,
num_dims
)
df
=
pd
.
DataFrame
(
{
"node_id"
:
node_id
,
"label"
:
label
,
"feat"
:
[
line
.
tolist
()
for
line
in
feat
],
}
)
df
.
to_csv
(
csv_path
,
index
=
False
)
dp
=
DefaultDataParser
()
dp
=
DefaultDataParser
()
df
=
pd
.
read_csv
(
csv_path
)
df
=
pd
.
read_csv
(
csv_path
)
dt
=
dp
(
df
)
assert
np
.
array_equal
(
node_id
,
dt
[
"node_id"
])
# Intercepting the warning: "Unamed column is found. Ignored...".
assert
np
.
array_equal
(
label
,
dt
[
"label"
])
with
warnings
.
catch_warnings
():
assert
np
.
array_equal
(
feat
,
dt
[
"feat"
])
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
return
dp
(
df
)
def
_test_DefaultDataParser
():
# common csv
num_nodes
=
5
num_labels
=
3
num_dims
=
2
node_id
=
np
.
arange
(
num_nodes
)
label
=
np
.
random
.
randint
(
num_labels
,
size
=
num_nodes
)
feat
=
np
.
random
.
rand
(
num_nodes
,
num_dims
)
df
=
pd
.
DataFrame
(
{
"node_id"
:
node_id
,
"label"
:
label
,
"feat"
:
[
line
.
tolist
()
for
line
in
feat
],
}
)
dt
=
_get_data_table
(
df
)
assert
np
.
array_equal
(
node_id
,
dt
[
"node_id"
])
assert
np
.
array_equal
(
label
,
dt
[
"label"
])
assert
np
.
array_equal
(
feat
,
dt
[
"feat"
])
# string consists of non-numeric values
# string consists of non-numeric values
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
df
=
pd
.
DataFrame
({
"label"
:
[
"a"
,
"b"
,
"c"
]})
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
expect_except
=
False
df
=
pd
.
DataFrame
({
"label"
:
[
"a"
,
"b"
,
"c"
]})
try
:
df
.
to_csv
(
csv_path
,
index
=
False
)
_get_data_table
(
df
)
dp
=
DefaultDataParser
()
except
:
df
=
pd
.
read_csv
(
csv_path
)
expect_except
=
True
expect_except
=
False
assert
expect_except
try
:
dt
=
dp
(
df
)
except
:
expect_except
=
True
assert
expect_except
# csv has index column which is ignored as it's unnamed
# csv has index column which is ignored as it's unnamed
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
df
=
pd
.
DataFrame
({
"label"
:
[
1
,
2
,
3
]})
csv_path
=
os
.
path
.
join
(
test_dir
,
"nodes.csv"
)
dt
=
_get_data_table
(
df
)
df
=
pd
.
DataFrame
({
"label"
:
[
1
,
2
,
3
]})
assert
len
(
dt
)
==
1
df
.
to_csv
(
csv_path
)
dp
=
DefaultDataParser
()
df
=
pd
.
read_csv
(
csv_path
)
dt
=
dp
(
df
)
assert
len
(
dt
)
==
1
def
_test_load_yaml_with_sanity_check
():
def
_test_load_yaml_with_sanity_check
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment