Unverified commit 8b839a23, authored by Andrei Ivanov and committed by GitHub
Browse files

Improving data tests. (#6144)

parent 13204383
...@@ -7,6 +7,7 @@ from __future__ import absolute_import ...@@ -7,6 +7,7 @@ from __future__ import absolute_import
import os, sys import os, sys
import pickle as pkl import pickle as pkl
import warnings
import networkx as nx import networkx as nx
...@@ -34,6 +35,8 @@ backend = os.environ.get("DGLBACKEND", "pytorch") ...@@ -34,6 +35,8 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
def _pickle_load(pkl_file): def _pickle_load(pkl_file):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
if sys.version_info > (3, 0): if sys.version_info > (3, 0):
return pkl.load(pkl_file, encoding="latin1") return pkl.load(pkl_file, encoding="latin1")
else: else:
......
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import tarfile import tarfile
import tempfile import tempfile
import unittest import unittest
import warnings
import backend as F import backend as F
...@@ -736,12 +737,24 @@ def _test_construct_graphs_multiple(): ...@@ -736,12 +737,24 @@ def _test_construct_graphs_multiple():
assert expect_except assert expect_except
def _test_DefaultDataParser(): def _get_data_table(data_frame):
from dgl.data.csv_dataset_base import DefaultDataParser from dgl.data.csv_dataset_base import DefaultDataParser
# common csv
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv") csv_path = os.path.join(test_dir, "nodes.csv")
data_frame.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
# Intercepting the warning: "Unamed column is found. Ignored...".
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return dp(df)
def _test_DefaultDataParser():
# common csv
num_nodes = 5 num_nodes = 5
num_labels = 3 num_labels = 3
num_dims = 2 num_dims = 2
...@@ -755,34 +768,24 @@ def _test_DefaultDataParser(): ...@@ -755,34 +768,24 @@ def _test_DefaultDataParser():
"feat": [line.tolist() for line in feat], "feat": [line.tolist() for line in feat],
} }
) )
df.to_csv(csv_path, index=False)
dp = DefaultDataParser() dt = _get_data_table(df)
df = pd.read_csv(csv_path)
dt = dp(df)
assert np.array_equal(node_id, dt["node_id"]) assert np.array_equal(node_id, dt["node_id"])
assert np.array_equal(label, dt["label"]) assert np.array_equal(label, dt["label"])
assert np.array_equal(feat, dt["feat"]) assert np.array_equal(feat, dt["feat"])
# string consists of non-numeric values # string consists of non-numeric values
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": ["a", "b", "c"]}) df = pd.DataFrame({"label": ["a", "b", "c"]})
df.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
expect_except = False expect_except = False
try: try:
dt = dp(df) _get_data_table(df)
except: except:
expect_except = True expect_except = True
assert expect_except assert expect_except
# csv has index column which is ignored as it's unnamed # csv has index column which is ignored as it's unnamed
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": [1, 2, 3]}) df = pd.DataFrame({"label": [1, 2, 3]})
df.to_csv(csv_path) dt = _get_data_table(df)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
assert len(dt) == 1 assert len(dt) == 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment