Unverified Commit 8b839a23 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving data tests. (#6144)

parent 13204383
......@@ -7,6 +7,7 @@ from __future__ import absolute_import
import os, sys
import pickle as pkl
import warnings
import networkx as nx
......@@ -34,10 +35,12 @@ backend = os.environ.get("DGLBACKEND", "pytorch")
def _pickle_load(pkl_file):
if sys.version_info > (3, 0):
return pkl.load(pkl_file, encoding="latin1")
else:
return pkl.load(pkl_file)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
if sys.version_info > (3, 0):
return pkl.load(pkl_file, encoding="latin1")
else:
return pkl.load(pkl_file)
class CitationGraphDataset(DGLBuiltinDataset):
......
......@@ -4,6 +4,7 @@ import os
import tarfile
import tempfile
import unittest
import warnings
import backend as F
......@@ -736,54 +737,56 @@ def _test_construct_graphs_multiple():
assert expect_except
def _test_DefaultDataParser():
def _get_data_table(data_frame):
from dgl.data.csv_dataset_base import DefaultDataParser
# common csv
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
num_nodes = 5
num_labels = 3
num_dims = 2
node_id = np.arange(num_nodes)
label = np.random.randint(num_labels, size=num_nodes)
feat = np.random.rand(num_nodes, num_dims)
df = pd.DataFrame(
{
"node_id": node_id,
"label": label,
"feat": [line.tolist() for line in feat],
}
)
df.to_csv(csv_path, index=False)
data_frame.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
assert np.array_equal(node_id, dt["node_id"])
assert np.array_equal(label, dt["label"])
assert np.array_equal(feat, dt["feat"])
# Intercepting the warning: "Unamed column is found. Ignored...".
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return dp(df)
def _test_DefaultDataParser():
# common csv
num_nodes = 5
num_labels = 3
num_dims = 2
node_id = np.arange(num_nodes)
label = np.random.randint(num_labels, size=num_nodes)
feat = np.random.rand(num_nodes, num_dims)
df = pd.DataFrame(
{
"node_id": node_id,
"label": label,
"feat": [line.tolist() for line in feat],
}
)
dt = _get_data_table(df)
assert np.array_equal(node_id, dt["node_id"])
assert np.array_equal(label, dt["label"])
assert np.array_equal(feat, dt["feat"])
# string consists of non-numeric values
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": ["a", "b", "c"]})
df.to_csv(csv_path, index=False)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
expect_except = False
try:
dt = dp(df)
except:
expect_except = True
assert expect_except
df = pd.DataFrame({"label": ["a", "b", "c"]})
expect_except = False
try:
_get_data_table(df)
except:
expect_except = True
assert expect_except
# csv has index column which is ignored as it's unnamed
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
df = pd.DataFrame({"label": [1, 2, 3]})
df.to_csv(csv_path)
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
assert len(dt) == 1
df = pd.DataFrame({"label": [1, 2, 3]})
dt = _get_data_table(df)
assert len(dt) == 1
def _test_load_yaml_with_sanity_check():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment