Unverified Commit 8db2dd33 authored by Rhett Ying, committed by GitHub

[Dataset][Fix] Allow non-numeric values and some fix on doc (#3757)

* [Fix] Be able to parse IDs when numeric and non-numeric values are used together

* Add required package info and a cache note to the docstring

* Duplicate node IDs are not allowed
parent bf649d94
......@@ -110,6 +110,8 @@ After loaded, the dataset has one graph without any features:
.. note::
Non-integer node IDs are allowed. When constructing the graph, ``CSVDataset`` will
map each raw ID to an integer ID starting from zero.
If the node IDs are already distinct integers from 0 to ``num_nodes-1``, no mapping
is applied.
.. note::
Edges are always directed. To have both directions, add reversed edges in the edge
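A small aside on the ID-mapping note above: the sketch below (not ``CSVDataset``'s internal code path) shows how raw IDs can be mapped to integers starting from zero with ``np.unique``, the same construction that appears in the ``NodeData.to_dict`` hunk further down. The raw IDs are made up for illustration:

    import numpy as np

    # Hypothetical raw node IDs from a CSV column: non-numeric and unsorted.
    raw_ids = np.array(['n3', 'n1', 'n7'])

    # np.unique returns the unique values in sorted order; enumerating them
    # assigns each raw ID an integer ID starting from zero.
    u_ids = np.unique(raw_ids)
    mapping = {raw: i for i, raw in enumerate(u_ids)}
    # mapping == {'n1': 0, 'n3': 1, 'n7': 2}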
......@@ -307,7 +309,7 @@ load graph-level features from.
node_data:
- file_name: nodes.csv
graph_data:
- file_name: graphs.csv
file_name: graphs.csv
To distinguish nodes and edges of different graphs, the ``node.csv`` and ``edge.csv`` must contain
an extra column ``graph_id``:
......
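For context on the ``graphs.csv`` change above: ``node_data`` and ``edge_data`` are lists of per-file entries, while ``graph_data`` points at a single file, so the leading ``-`` (a YAML list marker) is dropped. A sketch of a full ``meta.yaml`` under the corrected schema, with placeholder dataset and file names:

    dataset_name: my_dataset        # placeholder name
    node_data:
    - file_name: nodes.csv
    edge_data:
    - file_name: edges.csv
    graph_data:
      file_name: graphs.csv         # single mapping, not a list entry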
......@@ -8,6 +8,16 @@ from ..base import DGLError
class CSVDataset(DGLDataset):
"""Dataset class that loads and parses graph data from CSV files.
This class requires the following additional packages:
- pyyaml >= 5.4.1
- pandas >= 1.1.5
- pydantic >= 1.9.0
The parsed graph and feature data will be cached for faster reloading. If
the source CSV files are modified, please specify ``force_reload=True``
to re-parse from them.
Parameters
----------
data_path : str
......
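A hedged usage sketch of the class with the caching behavior described in the docstring; ``./my_dataset`` is a placeholder directory containing ``meta.yaml`` and the CSV files:

    import dgl

    # The first call parses the CSV files and caches the parsed data on disk.
    dataset = dgl.data.CSVDataset('./my_dataset')

    # After editing the source CSVs, force a re-parse instead of loading the cache.
    dataset = dgl.data.CSVDataset('./my_dataset', force_reload=True)

    g = dataset[0]  # a DGLGraph (graph-level labels are returned too, if present)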
......@@ -131,23 +131,25 @@ class NodeData(BaseData):
@staticmethod
def to_dict(node_data: List['NodeData']) -> dict:
# node_ids could be arbitrary numeric values, namely non-sorted, duplicated, not labeled from 0 to num_nodes-1
# node_ids could be numeric or non-numeric values, but duplication is not allowed.
node_dict = {}
for n_data in node_data:
graph_ids = np.unique(n_data.graph_id)
for graph_id in graph_ids:
idx = n_data.graph_id == graph_id
ids = n_data.id[idx]
u_ids, u_indices = np.unique(ids, return_index=True)
u_ids, u_indices, u_counts = np.unique(
ids, return_index=True, return_counts=True)
if len(ids) > len(u_ids):
dgl_warning(
"There exist duplicated ids and only the first ones are kept.")
raise DGLError("Node IDs are required to be unique but the following ids are duplicate: {}".format(
u_ids[u_counts > 1]))
if graph_id not in node_dict:
node_dict[graph_id] = {}
node_dict[graph_id][n_data.type] = {'mapping': {index: i for i,
index in enumerate(ids[u_indices])},
'data': {k: _tensor(v[idx][u_indices])
for k, v in n_data.data.items()}}
for k, v in n_data.data.items()},
'dtype': ids.dtype}
return node_dict
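The new duplicate check can be reproduced outside the class. A minimal sketch of the ``np.unique``-based detection, with made-up IDs:

    import numpy as np

    ids = np.array(['a', 'b', 'a', 'c'])
    u_ids, u_indices, u_counts = np.unique(
        ids, return_index=True, return_counts=True)
    if len(ids) > len(u_ids):
        # u_counts > 1 selects the values that occur more than once: ['a']
        print('Duplicate IDs:', u_ids[u_counts > 1])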
......@@ -192,8 +194,10 @@ class EdgeData(BaseData):
idx = e_data.graph_id == graph_id
src_mapping = node_dict[graph_id][src_type]['mapping']
dst_mapping = node_dict[graph_id][dst_type]['mapping']
src_ids = [src_mapping[index] for index in e_data.src[idx]]
dst_ids = [dst_mapping[index] for index in e_data.dst[idx]]
orig_src_ids = e_data.src[idx].astype(node_dict[graph_id][src_type]['dtype'])
orig_dst_ids = e_data.dst[idx].astype(node_dict[graph_id][dst_type]['dtype'])
src_ids = [src_mapping[index] for index in orig_src_ids]
dst_ids = [dst_mapping[index] for index in orig_dst_ids]
if graph_id not in edge_dict:
edge_dict[graph_id] = {}
edge_dict[graph_id][e_data.type] = {'edges': (_tensor(src_ids), _tensor(dst_ids)),
......
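Why the added ``astype`` cast matters: when a node ID column mixes values such as ``1`` and ``'a'``, it is parsed as strings, while an edge endpoint column that happens to contain only numbers is parsed as integers, so uncast lookups would miss the mapping keys. A minimal sketch with made-up arrays:

    import numpy as np

    # Node IDs parsed as strings because the column mixes '1', '2' and 'a'.
    node_ids = np.array(['1', '2', 'a'])
    mapping = {raw: i for i, raw in enumerate(np.unique(node_ids))}

    # Edge endpoints that contain only numbers parse as int64.
    src = np.array([1, 2])

    # mapping[1] would raise KeyError; casting to the node ID dtype makes
    # the keys line up, which is exactly what the cast above does.
    src_ids = [mapping[s] for s in src.astype(node_ids.dtype)]
    # src_ids == [0, 1]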
......@@ -218,24 +218,94 @@ def test_extract_archive():
assert os.path.exists(os.path.join(dst_dir, gz_file))
def _test_construct_graphs_node_ids():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
num_nodes = 100
num_edges = 1000
# node IDs are required to be unique
node_ids = np.random.choice(np.arange(num_nodes / 2), num_nodes)
src_ids = np.random.choice(node_ids, size=num_edges)
dst_ids = np.random.choice(node_ids, size=num_edges)
node_data = NodeData(node_ids, {})
edge_data = EdgeData(src_ids, dst_ids, {})
expect_except = False
try:
_, _ = DGLGraphConstructor.construct_graphs(
node_data, edge_data)
except:
expect_except = True
assert expect_except
# node IDs are already labelled from 0~num_nodes-1
node_ids = np.arange(num_nodes)
np.random.shuffle(node_ids)
_, idx = np.unique(node_ids, return_index=True)
src_ids = np.random.choice(node_ids, size=num_edges)
dst_ids = np.random.choice(node_ids, size=num_edges)
node_feat = np.random.rand(num_nodes, 3)
node_data = NodeData(node_ids, {'feat':node_feat})
edge_data = EdgeData(src_ids, dst_ids, {})
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data)
assert len(graphs) == 1
assert len(data_dict) == 0
g = graphs[0]
assert g.is_homogeneous
assert g.num_nodes() == len(node_ids)
assert g.num_edges() == len(src_ids)
assert F.array_equal(F.tensor(node_feat[idx], dtype=F.float32), g.ndata['feat'])
# node IDs are mixed with numeric and non-numeric values
# homogeneous graph
node_ids = [1, 2, 3, 'a']
src_ids = [1, 2, 3]
dst_ids = ['a', 1, 2]
node_data = NodeData(node_ids, {})
edge_data = EdgeData(src_ids, dst_ids, {})
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data)
assert len(graphs) == 1
assert len(data_dict) == 0
g = graphs[0]
assert g.is_homogeneous
assert g.num_nodes() == len(node_ids)
assert g.num_edges() == len(src_ids)
# heterogeneous graph
node_ids_user = [1, 2, 3]
node_ids_item = ['a', 'b', 'c']
src_ids = node_ids_user
dst_ids = node_ids_item
node_data_user = NodeData(node_ids_user, {}, type='user')
node_data_item = NodeData(node_ids_item, {}, type='item')
edge_data = EdgeData(src_ids, dst_ids, {}, type=('user', 'like', 'item'))
graphs, data_dict = DGLGraphConstructor.construct_graphs(
[node_data_user, node_data_item], edge_data)
assert len(graphs) == 1
assert len(data_dict) == 0
g = graphs[0]
assert not g.is_homogeneous
assert g.num_nodes('user') == len(node_ids_user)
assert g.num_nodes('item') == len(node_ids_item)
assert g.num_edges() == len(src_ids)
def _test_construct_graphs_homo():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
# node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
# node_id could be non-sorted, non-numeric.
num_nodes = 100
num_edges = 1000
num_dims = 3
num_dup_nodes = int(num_nodes*0.2)
node_ids = np.random.choice(
np.arange(num_nodes*2), size=num_nodes, replace=False)
assert len(node_ids) == num_nodes
# to be non-sorted
np.random.shuffle(node_ids)
# to be duplicated
node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
# to be non-numeric
node_ids = ['id_{}'.format(id) for id in node_ids]
t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
t_ndata = {'feat': np.random.rand(num_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes)}
_, u_indices = np.unique(node_ids, return_index=True)
ndata = {'feat': t_ndata['feat'][u_indices],
'label': t_ndata['label'][u_indices]}
......@@ -270,7 +340,6 @@ def _test_construct_graphs_hetero():
num_nodes = 100
num_edges = 1000
num_dims = 3
num_dup_nodes = int(num_nodes*0.2)
ntypes = ['user', 'item']
node_data = []
node_ids_dict = {}
......@@ -281,12 +350,10 @@ def _test_construct_graphs_hetero():
assert len(node_ids) == num_nodes
# to be non-sorted
np.random.shuffle(node_ids)
# to be duplicated
node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
# to be non-numeric
node_ids = ['id_{}'.format(id) for id in node_ids]
t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
t_ndata = {'feat': np.random.rand(num_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes)}
_, u_indices = np.unique(node_ids, return_index=True)
ndata = {'feat': t_ndata['feat'][u_indices],
'label': t_ndata['label'][u_indices]}
......@@ -1095,6 +1162,7 @@ def _test_NodeEdgeGraphData():
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_csvdataset():
_test_NodeEdgeGraphData()
_test_construct_graphs_node_ids()
_test_construct_graphs_homo()
_test_construct_graphs_hetero()
_test_construct_graphs_multiple()
......