Unverified Commit 4b295d60 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[FIX] fix compute/test_data.py::test_csvdataset (#3643)

parent 6dce19d8
...@@ -716,7 +716,7 @@ def _test_DGLCSVDataset_single(): ...@@ -716,7 +716,7 @@ def _test_DGLCSVDataset_single():
# remove original node data file to verify reload from cached files # remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0) os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0) assert not os.path.exists(nodes_csv_path_0)
csv_dataset = data.DGLCSVDataset( csv_dataset = csv_ds.DGLCSVDataset(
test_dir, force_reload=force_reload) test_dir, force_reload=force_reload)
assert len(csv_dataset) == 1 assert len(csv_dataset) == 1
g = csv_dataset[0] g = csv_dataset[0]
...@@ -799,7 +799,7 @@ def _test_DGLCSVDataset_multiple(): ...@@ -799,7 +799,7 @@ def _test_DGLCSVDataset_multiple():
# remove original node data file to verify reload from cached files # remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0) os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0) assert not os.path.exists(nodes_csv_path_0)
csv_dataset = data.DGLCSVDataset( csv_dataset = csv_ds.DGLCSVDataset(
test_dir, force_reload=force_reload) test_dir, force_reload=force_reload)
assert len(csv_dataset) == num_graphs assert len(csv_dataset) == num_graphs
assert csv_dataset.has_cache() assert csv_dataset.has_cache()
...@@ -885,7 +885,7 @@ def _test_DGLCSVDataset_customized_data_parser(): ...@@ -885,7 +885,7 @@ def _test_DGLCSVDataset_customized_data_parser():
data[header] = dt data[header] = dt
return data return data
# load CSVDataset with customized node/edge/graph_data_parser # load CSVDataset with customized node/edge/graph_data_parser
csv_dataset = data.DGLCSVDataset( csv_dataset = csv_ds.DGLCSVDataset(
test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser()) test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser())
assert len(csv_dataset) == num_graphs assert len(csv_dataset) == num_graphs
assert len(csv_dataset.data) == 1 assert len(csv_dataset.data) == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment