Unverified commit 39121dfd authored by Rhett Ying, committed by GitHub

[Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset (#3740)

* [Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset

* change return value when iterating dataset

* refine data_parser

* force reload
parent 42f8c8f3
@@ -29,7 +29,7 @@ from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import DGLCSVDataset
from .csv_dataset import CSVDataset
from .adapter import AsNodePredDataset, AsLinkPredDataset
def register_data_args(parser):
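With this rename, downstream code imports CSVDataset instead of DGLCSVDataset. A minimal usage sketch (the dataset directory path is hypothetical, not part of this diff):

from dgl.data import CSVDataset

# './my_csv_dataset' is a hypothetical directory containing meta.yaml
# and the node/edge CSV files it points to.
dataset = CSVDataset('./my_csv_dataset')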
@@ -5,7 +5,7 @@ from .utils import save_graphs, load_graphs
from ..base import DGLError
class DGLCSVDataset(DGLDataset):
class CSVDataset(DGLDataset):
""" This class aims to parse data from CSV files, construct DGLGraph
and behaves as a DGLDataset.
@@ -17,22 +17,27 @@ class DGLCSVDataset(DGLDataset):
Whether to reload the dataset. Default: False
verbose: bool, optional
Whether to print out progress information. Default: True.
node_data_parser : dict[str, callable], optional
A dictionary used for node data parsing when loading from CSV files.
The key is node type which specifies the header in CSV file and the
value is a callable object which is used to parse corresponding
column data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
edge_data_parser : dict[(str, str, str), callable], optional
A dictionary used for edge data parsing when loading from CSV files.
The key is edge type which specifies the header in CSV file and the
value is a callable object which is used to parse corresponding
column data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
graph_data_parser : callable, optional
A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
ndata_parser : dict[str, callable] or callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from
a CSV file, parses node data and returns a dictionary of parsed data. If given a
dictionary, the key is the node type and the value is a callable object which is
used to parse data of the corresponding node type. If given a single callable
object, such an object is used to parse data of all node types. Default: None.
If None, a default data parser is applied which loads data directly and tries to
convert lists into arrays.
edata_parser : dict[(str, str, str), callable] or callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from
a CSV file, parses edge data and returns a dictionary of parsed data. If given a
dictionary, the key is the edge type and the value is a callable object which is
used to parse data of the corresponding edge type. If given a single callable
object, such an object is used to parse data of all edge types. Default: None.
If None, a default data parser is applied which loads data directly and tries to
convert lists into arrays.
gdata_parser : callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from
a CSV file, parses graph data and returns a dictionary of parsed data. Default:
None. If None, a default data parser is applied which loads data directly and
tries to convert lists into arrays.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
@@ -50,19 +55,19 @@ class DGLCSVDataset(DGLDataset):
"""
META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
def __init__(self, data_path, force_reload=False, verbose=True, ndata_parser=None,
edata_parser=None, gdata_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
self.node_data_parser = {} if node_data_parser is None else node_data_parser
self.edge_data_parser = {} if edge_data_parser is None else edge_data_parser
self.graph_data_parser = graph_data_parser
self.ndata_parser = {} if ndata_parser is None else ndata_parser
self.edata_parser = {} if edata_parser is None else edata_parser
self.gdata_parser = gdata_parser
self.default_data_parser = DefaultDataParser()
meta_yaml_path = os.path.join(data_path, DGLCSVDataset.META_YAML_NAME)
meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)
if not os.path.exists(meta_yaml_path):
raise DGLError(
"'{}' cannot be found under {}.".format(DGLCSVDataset.META_YAML_NAME, data_path))
"'{}' cannot be found under {}.".format(CSVDataset.META_YAML_NAME, data_path))
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname(
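A hedged sketch of calling the renamed constructor arguments; the directory path and the parser below are illustrative, not part of this diff:

def to_numpy_parser(df):
    # Takes the pandas.DataFrame built from one CSV file and returns a
    # dict of parsed column data (here simply converted to numpy arrays).
    return {header: df[header].to_numpy() for header in df}

ds = CSVDataset('./my_ds',
                ndata_parser={'user': to_numpy_parser},  # dict keyed by node type
                edata_parser=to_numpy_parser,            # or one callable for all types
                gdata_parser=None)                       # None falls back to the default parser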
@@ -80,8 +85,8 @@ class DGLCSVDataset(DGLDataset):
if meta_node is None:
continue
ntype = meta_node.ntype
data_parser = self.node_data_parser.get(
ntype, self.default_data_parser)
data_parser = self.ndata_parser if callable(
self.ndata_parser) else self.ndata_parser.get(ntype, self.default_data_parser)
ndata = NodeData.load_from_csv(
meta_node, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
node_data.append(ndata)
@@ -90,15 +95,15 @@ class DGLCSVDataset(DGLDataset):
if meta_edge is None:
continue
etype = tuple(meta_edge.etype)
data_parser = self.edge_data_parser.get(
etype, self.default_data_parser)
data_parser = self.edata_parser if callable(
self.edata_parser) else self.edata_parser.get(etype, self.default_data_parser)
edata = EdgeData.load_from_csv(
meta_edge, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
edge_data.append(edata)
graph_data = None
if meta_yaml.graph_data is not None:
meta_graph = meta_yaml.graph_data
data_parser = self.default_data_parser if self.graph_data_parser is None else self.graph_data_parser
data_parser = self.default_data_parser if self.gdata_parser is None else self.gdata_parser
graph_data = GraphData.load_from_csv(
meta_graph, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
# construct graphs
@@ -132,8 +137,9 @@ class DGLCSVDataset(DGLDataset):
else:
g = self._transform(self.graphs[i])
if 'label' in self.data:
return g, self.data['label'][i]
if len(self.data) > 0:
data = {k: v[i] for (k, v) in self.data.items()}
return g, data
else:
return g
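Under the change above, a dataset carrying graph-level data yields the graph plus a dict of all such data instead of only the label. A sketch assuming csv_dataset holds graph-level 'feat' and 'label' columns:

g, g_data = csv_dataset[0]  # previously: g, label = csv_dataset[0]
feat, label = g_data['feat'], g_data['label']
# the same contract applies when iterating:
for g, g_data in csv_dataset:
    pass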
@@ -100,12 +100,13 @@ class NodeData(BaseData):
""" Class of node data which is used for DGLGraph construction. Internal use only. """
def __init__(self, node_id, data, type=None, graph_id=None):
self.id = np.array(node_id, dtype=np.int64)
self.id = np.array(node_id)
self.data = data
self.type = type if type is not None else '_V'
self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
len(node_id), 0)
_validate_data_length({**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
self.graph_id = np.array(
graph_id) if graph_id is not None else np.full(len(node_id), 0)
_validate_data_length(
{**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
@@ -145,13 +146,14 @@ class EdgeData(BaseData):
""" Class of edge data which is used for DGLGraph construction. Internal use only. """
def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
self.src = np.array(src_id, dtype=np.int64)
self.dst = np.array(dst_id, dtype=np.int64)
self.src = np.array(src_id)
self.dst = np.array(dst_id)
self.data = data
self.type = type if type is not None else ('_V', '_E', '_V')
self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
len(src_id), 0)
_validate_data_length({**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
self.graph_id = np.array(
graph_id) if graph_id is not None else np.full(len(src_id), 0)
_validate_data_length(
{**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
@@ -195,7 +197,7 @@ class GraphData(BaseData):
""" Class of graph data which is used for DGLGraph construction. Internal use only. """
def __init__(self, graph_id, data):
self.graph_id = np.array(graph_id, dtype=np.int64)
self.graph_id = np.array(graph_id)
self.data = data
_validate_data_length({**{'graph_id': self.graph_id}, **self.data})
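With the int64 casts removed above, IDs are stored as given, so non-numeric values work. A minimal sketch of these internal-use classes, with constructor signatures as shown in this diff:

from dgl.data.csv_dataset_base import NodeData, EdgeData, GraphData

ndata = NodeData(['id_0', 'id_1'], {})    # string node IDs
edata = EdgeData(['id_0'], ['id_1'], {})  # string src/dst IDs
gdata = GraphData(['graph_0'], {})        # string graph IDs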
@@ -269,7 +271,7 @@ class DGLGraphConstructor:
class DefaultDataParser:
""" Default data parser for DGLCSVDataset. It
""" Default data parser for CSVDataset. It
1. ignores any column which does not have a header.
2. tries to convert to a list of numeric values (generated by
np.array().tolist()) if the cell data is a str separated by ','.
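A hedged illustration of the documented behavior, not the parser's actual code: a ','-separated string cell becomes a list of numeric values, while headerless columns are ignored:

import numpy as np
import pandas as pd

df = pd.DataFrame({'feat': ['0.1,0.2', '0.3,0.4']})
# emulate rule 2: split each cell on ',' and convert to numbers
feat = np.array([[float(v) for v in cell.split(',')] for cell in df['feat']])
# feat is now array([[0.1, 0.2], [0.3, 0.4]])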
@@ -220,7 +220,7 @@ def test_extract_archive():
def _test_construct_graphs_homo():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
# node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
# node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
num_nodes = 100
num_edges = 1000
num_dims = 3
@@ -228,8 +228,12 @@ def _test_construct_graphs_homo():
node_ids = np.random.choice(
np.arange(num_nodes*2), size=num_nodes, replace=False)
assert len(node_ids) == num_nodes
# to be non-sorted
np.random.shuffle(node_ids)
# to be duplicated
node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
# to be non-numeric
node_ids = ['id_{}'.format(id) for id in node_ids]
t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
_, u_indices = np.unique(node_ids, return_index=True)
@@ -260,7 +264,7 @@ def _test_construct_graphs_homo():
def _test_construct_graphs_hetero():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
# node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
# node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
num_nodes = 100
num_edges = 1000
num_dims = 3
@@ -273,8 +277,12 @@ def _test_construct_graphs_hetero():
node_ids = np.random.choice(
np.arange(num_nodes*2), size=num_nodes, replace=False)
assert len(node_ids) == num_nodes
# to be non-sorted
np.random.shuffle(node_ids)
# to be duplicated
node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
# to be non-numeric
node_ids = ['id_{}'.format(id) for id in node_ids]
t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
_, u_indices = np.unique(node_ids, return_index=True)
@@ -341,13 +349,16 @@ def _test_construct_graphs_multiple():
egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
ndata = {'feat': np.random.rand(num_nodes*num_graphs, num_dims),
'label': np.random.randint(2, size=num_nodes*num_graphs)}
ngraph_ids = ['graph_{}'.format(id) for id in ngraph_ids]
node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
egraph_ids = ['graph_{}'.format(id) for id in egraph_ids]
edata = {'feat': np.random.rand(
num_edges*num_graphs, num_dims), 'label': np.random.randint(2, size=num_edges*num_graphs)}
edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
gdata = {'feat': np.random.rand(num_graphs, num_dims),
'label': np.random.randint(2, size=num_graphs)}
graph_data = GraphData(np.arange(num_graphs), gdata)
graph_ids = ['graph_{}'.format(id) for id in np.arange(num_graphs)]
graph_data = GraphData(graph_ids, gdata)
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data, graph_data)
assert len(graphs) == num_graphs
@@ -728,7 +739,7 @@ def _test_load_graph_data_from_csv():
assert expect_except
def _test_DGLCSVDataset_single():
def _test_CSVDataset_single():
with tempfile.TemporaryDirectory() as test_dir:
# generate YAML/CSVs
meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -779,7 +790,7 @@ def _test_DGLCSVDataset_single():
# remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0)
csv_dataset = data.DGLCSVDataset(
csv_dataset = data.CSVDataset(
test_dir, force_reload=force_reload)
assert len(csv_dataset) == 1
g = csv_dataset[0]
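The tests generate a meta.yaml plus CSV files under test_dir before loading. A hypothetical sketch of such a file; only dataset_name, separator, node_data, edge_data and graph_data appear in this diff, so the file_name keys and exact nesting are assumptions:

meta_yaml = """\
dataset_name: example_ds
node_data:
  - file_name: nodes_0.csv
edge_data:
  - file_name: edges_0.csv
"""
with open('meta.yaml', 'w') as f:
    f.write(meta_yaml)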
@@ -799,7 +810,7 @@ def _test_DGLCSVDataset_single():
F.asnumpy(g.edges[etype].data['label']))
def _test_DGLCSVDataset_multiple():
def _test_CSVDataset_multiple():
with tempfile.TemporaryDirectory() as test_dir:
# generate YAML/CSVs
meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -856,13 +867,13 @@ def _test_DGLCSVDataset_multiple():
})
df.to_csv(graph_csv_path, index=False)
# load CSVDataset with default node/edge/graph_data_parser
# load CSVDataset with default ndata/edata/gdata_parser
for force_reload in [True, False]:
if not force_reload:
# remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0)
csv_dataset = data.DGLCSVDataset(
csv_dataset = data.CSVDataset(
test_dir, force_reload=force_reload)
assert len(csv_dataset) == num_graphs
assert csv_dataset.has_cache()
@@ -871,9 +882,10 @@ def _test_DGLCSVDataset_multiple():
assert 'label' in csv_dataset.data
assert F.array_equal(F.tensor(feat_gdata),
csv_dataset.data['feat'])
for i, (g, label) in enumerate(csv_dataset):
for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous
assert F.asnumpy(label) == label_gdata[i]
assert F.asnumpy(g_data['label']) == label_gdata[i]
assert F.array_equal(g_data['feat'], F.tensor(feat_gdata[i]))
for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes
assert F.array_equal(F.tensor(feat_ndata[i*num_nodes:(i+1)*num_nodes]),
@@ -888,7 +900,7 @@ def _test_DGLCSVDataset_multiple():
F.asnumpy(g.edges[etype].data['label']))
def _test_DGLCSVDataset_customized_data_parser():
def _test_CSVDataset_customized_data_parser():
with tempfile.TemporaryDirectory() as test_dir:
# generate YAML/CSVs
meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -947,15 +959,18 @@ def _test_DGLCSVDataset_customized_data_parser():
dt += 2
data[header] = dt
return data
# load CSVDataset with customized node/edge/graph_data_parser
csv_dataset = data.DGLCSVDataset(
test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser())
# load CSVDataset with customized ndata/edata/gdata_parser
# specify via dict[ntype/etype, callable]
csv_dataset = data.CSVDataset(
test_dir, force_reload=True, ndata_parser={'user': CustDataParser()},
edata_parser={('user', 'like', 'item'): CustDataParser()},
gdata_parser=CustDataParser())
assert len(csv_dataset) == num_graphs
assert len(csv_dataset.data) == 1
assert 'label' in csv_dataset.data
for i, (g, label) in enumerate(csv_dataset):
for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous
assert F.asnumpy(label) == label_gdata[i] + 2
assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes
offset = 2 if ntype == 'user' else 0
@@ -966,6 +981,26 @@ def _test_DGLCSVDataset_customized_data_parser():
offset = 2 if etype == 'like' else 0
assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges]+offset,
F.asnumpy(g.edges[etype].data['label']))
# specify via callable
csv_dataset = data.CSVDataset(
test_dir, force_reload=True, ndata_parser=CustDataParser(),
edata_parser=CustDataParser(), gdata_parser=CustDataParser())
assert len(csv_dataset) == num_graphs
assert len(csv_dataset.data) == 1
assert 'label' in csv_dataset.data
for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous
assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes
offset = 2
assert np.array_equal(label_ndata[i*num_nodes:(i+1)*num_nodes]+offset,
F.asnumpy(g.nodes[ntype].data['label']))
for etype in g.etypes:
assert g.num_edges(etype) == num_edges
offset = 2
assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges]+offset,
F.asnumpy(g.edges[etype].data['label']))
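For reference, a hedged reconstruction of a parser in the spirit of CustDataParser (only its tail is visible in the hunk above); the class name and exact column handling are assumptions:

class OffsetLabelParser:
    # Takes the pandas.DataFrame from one CSV file and returns a dict of
    # column name -> values, shifting 'label' by 2 to match the offsets
    # the test asserts.
    def __call__(self, df):
        data = {}
        for header in df:
            dt = df[header].to_numpy().squeeze()
            if header == 'label':
                dt = dt + 2
            data[header] = dt
        return data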
def _test_NodeEdgeGraphData():
@@ -974,8 +1009,7 @@ def _test_NodeEdgeGraphData():
num_nodes = 100
node_ids = np.arange(num_nodes, dtype=np.float)
ndata = NodeData(node_ids, {})
assert ndata.id.dtype == np.int64
assert np.array_equal(ndata.id, node_ids.astype(np.int64))
assert np.array_equal(ndata.id, node_ids)
assert len(ndata.data) == 0
assert ndata.type == '_V'
assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
@@ -1017,8 +1051,6 @@ def _test_NodeEdgeGraphData():
graph_ids = np.arange(num_edges)
edata = EdgeData(src_ids, dst_ids, data,
type=etype, graph_id=graph_ids)
assert edata.src.dtype == np.int64
assert edata.dst.dtype == np.int64
assert np.array_equal(edata.src, src_ids)
assert np.array_equal(edata.dst, dst_ids)
assert edata.type == etype
@@ -1046,7 +1078,6 @@ def _test_NodeEdgeGraphData():
graph_ids = np.arange(num_graphs).astype(np.float)
data = {'feat': np.random.rand(num_graphs, 3)}
gdata = GraphData(graph_ids, data)
assert gdata.graph_id.dtype == np.int64
assert np.array_equal(gdata.graph_id, graph_ids)
assert len(gdata.data) == len(data)
for k, v in data.items():
@@ -1065,9 +1096,9 @@ def test_csvdataset():
_test_load_node_data_from_csv()
_test_load_edge_data_from_csv()
_test_load_graph_data_from_csv()
_test_DGLCSVDataset_single()
_test_DGLCSVDataset_multiple()
_test_DGLCSVDataset_customized_data_parser()
_test_CSVDataset_single()
_test_CSVDataset_multiple()
_test_CSVDataset_customized_data_parser()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_add_nodepred_split():
@@ -1181,7 +1212,7 @@ def test_as_nodepred_csvdataset():
})
df.to_csv(edges_csv_path, index=False)
ds = data.DGLCSVDataset(test_dir, force_reload=True)
ds = data.CSVDataset(test_dir, force_reload=True)
assert 'feat' in ds[0].ndata
assert 'label' in ds[0].ndata
assert 'train_mask' not in ds[0].ndata