Unverified commit 8b839a23, authored by Andrei Ivanov and committed by GitHub
Browse files

Improving data tests. (#6144)

parent 13204383
......@@ -7,6 +7,7 @@ from __future__ import absolute_import
import os, sys
import pickle as pkl
import warnings
import networkx as nx
......@@ -34,6 +35,8 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
def _pickle_load(pkl_file):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
if sys.version_info > (3, 0):
return pkl.load(pkl_file, encoding="latin1")
else:
......
......@@ -4,6 +4,7 @@ import os
import tarfile
import tempfile
import unittest
import warnings
import backend as F
......@@ -736,12 +737,24 @@ def _test_construct_graphs_multiple():
assert expect_except
def _test_DefaultDataParser():
def _get_data_table(data_frame):
from dgl.data.csv_dataset_base import DefaultDataParser
# common csv
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
data_frame.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
# Intercepting the warning: "Unamed column is found. Ignored...".
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return dp(df)
def _test_DefaultDataParser():
# common csv
num_nodes = 5
num_labels = 3
num_dims = 2
......@@ -755,34 +768,24 @@ def _test_DefaultDataParser():
"feat": [line.tolist() for line in feat],
}
)
df.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
dt = _get_data_table(df)
assert np.array_equal(node_id, dt["node_id"])
assert np.array_equal(label, dt["label"])
assert np.array_equal(feat, dt["feat"])
# string consists of non-numeric values
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": ["a", "b", "c"]})
df.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
expect_except = False
try:
dt = dp(df)
_get_data_table(df)
except:
expect_except = True
assert expect_except
# csv has index column which is ignored as it's unnamed
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": [1, 2, 3]})
df.to_csv(csv_path)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
dt = _get_data_table(df)
assert len(dt) == 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment