Unverified Commit 39121dfd authored by Rhett Ying, committed by GitHub

[Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset (#3740)

* [Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset

* change return value when iterating over the dataset

* refine data_parser

* force reload
parent 42f8c8f3
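To illustrate the feature end to end, here is a minimal, hypothetical sketch of loading a CSV dataset whose IDs are strings. The directory layout follows DGL's documented CSV format, but the dataset name, file contents, and values are invented for illustration.

```python
# Hypothetical sketch: a CSVDataset whose node/edge IDs are strings.
# Assumes a DGL build containing this commit, plus pandas installed.
import os
import tempfile

import pandas as pd
import dgl.data

root = tempfile.mkdtemp()
with open(os.path.join(root, 'meta.yaml'), 'w') as f:
    f.write('dataset_name: toy_csv\n'
            'node_data:\n'
            '- file_name: nodes.csv\n'
            'edge_data:\n'
            '- file_name: edges.csv\n')
pd.DataFrame({'node_id': ['alice', 'bob', 'carol'],   # non-numeric IDs
              'label': [0, 1, 0]}).to_csv(os.path.join(root, 'nodes.csv'), index=False)
pd.DataFrame({'src_id': ['alice', 'bob'],             # string endpoints
              'dst_id': ['bob', 'carol']}).to_csv(os.path.join(root, 'edges.csv'), index=False)

ds = dgl.data.CSVDataset(root, force_reload=True)
g = ds[0]   # no graph-level data, so indexing yields the DGLGraph alone
print(g.num_nodes(), g.num_edges())   # 3 2
```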
@@ -29,7 +29,7 @@ from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
 from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
 from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
 from .fakenews import FakeNewsDataset
-from .csv_dataset import DGLCSVDataset
+from .csv_dataset import CSVDataset
 from .adapter import AsNodePredDataset, AsLinkPredDataset

 def register_data_args(parser):
...
@@ -5,7 +5,7 @@ from .utils import save_graphs, load_graphs
 from ..base import DGLError

-class DGLCSVDataset(DGLDataset):
+class CSVDataset(DGLDataset):
     """ This class aims to parse data from CSV files, construct DGLGraph
     and behaves as a DGLDataset.
@@ -17,22 +17,27 @@ class DGLCSVDataset(DGLDataset):
         Whether to reload the dataset. Default: False
     verbose: bool, optional
         Whether to print out progress information. Default: True.
-    node_data_parser : dict[str, callable], optional
-        A dictionary used for node data parsing when loading from CSV files.
-        The key is node type which specifies the header in CSV file and the
-        value is a callable object which is used to parse corresponding
-        column data. Default: None. If None, a default data parser is applied
-        which load data directly and tries to convert list into array.
-    edge_data_parser : dict[(str, str, str), callable], optional
-        A dictionary used for edge data parsing when loading from CSV files.
-        The key is edge type which specifies the header in CSV file and the
-        value is a callable object which is used to parse corresponding
-        column data. Default: None. If None, a default data parser is applied
-        which load data directly and tries to convert list into array.
-    graph_data_parser : callable, optional
-        A callable object which is used to parse corresponding column graph
-        data. Default: None. If None, a default data parser is applied
-        which load data directly and tries to convert list into array.
+    ndata_parser : dict[str, callable] or callable, optional
+        Callable object which takes in the ``pandas.DataFrame`` object created from
+        a CSV file, parses node data and returns a dictionary of parsed data. If given
+        a dictionary, the key is node type and the value is a callable object which is
+        used to parse data of the corresponding node type. If given a single callable
+        object, it is used to parse data of all node types. Default: None. If None,
+        a default data parser is applied which loads data directly and tries to
+        convert a list into an array.
+    edata_parser : dict[(str, str, str), callable] or callable, optional
+        Callable object which takes in the ``pandas.DataFrame`` object created from
+        a CSV file, parses edge data and returns a dictionary of parsed data. If given
+        a dictionary, the key is edge type and the value is a callable object which is
+        used to parse data of the corresponding edge type. If given a single callable
+        object, it is used to parse data of all edge types. Default: None. If None,
+        a default data parser is applied which loads data directly and tries to
+        convert a list into an array.
+    gdata_parser : callable, optional
+        Callable object which takes in the ``pandas.DataFrame`` object created from
+        a CSV file, parses graph data and returns a dictionary of parsed data. Default:
+        None. If None, a default data parser is applied which loads data directly and
+        tries to convert a list into an array.
     transform : callable, optional
         A transform that takes in a :class:`~dgl.DGLGraph` object and returns
         a transformed version. The :class:`~dgl.DGLGraph` object will be
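To make the parser contract above concrete, here is a hedged sketch of one such callable, modeled on the `CustDataParser` used in the tests further down; the `label` column name and the offset are assumptions for illustration.

```python
import numpy as np

class AddOffsetParser:
    """Hypothetical data parser: receives the pandas.DataFrame built from
    one CSV file and returns a dict of column name -> array."""

    def __init__(self, offset=2):
        self.offset = offset

    def __call__(self, df):
        data = {}
        for header in df:
            if 'Unnamed' in header:   # skip pandas' auto-named index columns
                continue
            dt = df[header].to_numpy().squeeze()
            if header == 'label':     # assumed column; shift labels by offset
                dt = dt + self.offset
            data[header] = dt
        return data
```

Passed as `ndata_parser=AddOffsetParser()` it would apply to every node type; passed as `ndata_parser={'user': AddOffsetParser()}` it would apply only to the `user` type, with the default parser covering the rest.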
@@ -50,19 +55,19 @@ class DGLCSVDataset(DGLDataset):
     """
     META_YAML_NAME = 'meta.yaml'

-    def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
-                 edge_data_parser=None, graph_data_parser=None, transform=None):
+    def __init__(self, data_path, force_reload=False, verbose=True, ndata_parser=None,
+                 edata_parser=None, gdata_parser=None, transform=None):
         from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
         self.graphs = None
         self.data = None
-        self.node_data_parser = {} if node_data_parser is None else node_data_parser
-        self.edge_data_parser = {} if edge_data_parser is None else edge_data_parser
-        self.graph_data_parser = graph_data_parser
+        self.ndata_parser = {} if ndata_parser is None else ndata_parser
+        self.edata_parser = {} if edata_parser is None else edata_parser
+        self.gdata_parser = gdata_parser
         self.default_data_parser = DefaultDataParser()
-        meta_yaml_path = os.path.join(data_path, DGLCSVDataset.META_YAML_NAME)
+        meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)
         if not os.path.exists(meta_yaml_path):
             raise DGLError(
-                "'{}' cannot be found under {}.".format(DGLCSVDataset.META_YAML_NAME, data_path))
+                "'{}' cannot be found under {}.".format(CSVDataset.META_YAML_NAME, data_path))
         self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
         ds_name = self.meta_yaml.dataset_name
         super().__init__(ds_name, raw_dir=os.path.dirname(
@@ -80,8 +85,8 @@ class DGLCSVDataset(DGLDataset):
             if meta_node is None:
                 continue
             ntype = meta_node.ntype
-            data_parser = self.node_data_parser.get(
-                ntype, self.default_data_parser)
+            data_parser = self.ndata_parser if callable(
+                self.ndata_parser) else self.ndata_parser.get(ntype, self.default_data_parser)
             ndata = NodeData.load_from_csv(
                 meta_node, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
             node_data.append(ndata)
@@ -90,15 +95,15 @@ class DGLCSVDataset(DGLDataset):
             if meta_edge is None:
                 continue
             etype = tuple(meta_edge.etype)
-            data_parser = self.edge_data_parser.get(
-                etype, self.default_data_parser)
+            data_parser = self.edata_parser if callable(
+                self.edata_parser) else self.edata_parser.get(etype, self.default_data_parser)
             edata = EdgeData.load_from_csv(
                 meta_edge, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
             edge_data.append(edata)
         graph_data = None
         if meta_yaml.graph_data is not None:
             meta_graph = meta_yaml.graph_data
-            data_parser = self.default_data_parser if self.graph_data_parser is None else self.graph_data_parser
+            data_parser = self.default_data_parser if self.gdata_parser is None else self.gdata_parser
             graph_data = GraphData.load_from_csv(
                 meta_graph, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
         # construct graphs
@@ -132,8 +137,9 @@ class DGLCSVDataset(DGLDataset):
         else:
             g = self._transform(self.graphs[i])
-        if 'label' in self.data:
-            return g, self.data['label'][i]
+        if len(self.data) > 0:
+            data = {k: v[i] for (k, v) in self.data.items()}
+            return g, data
         else:
             return g
...
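The `__getitem__` change above alters what iteration yields when graph-level data is present: a `(graph, dict)` pair carrying every graph-level column, not just a label. A hedged sketch (dataset and column names are assumed):

```python
# `csv_dataset` is assumed to be a multi-graph CSVDataset whose graph CSV
# provides 'feat' and 'label' columns; datasets without graph-level data
# still yield the bare DGLGraph.
for g, gdata in csv_dataset:
    feat = gdata['feat']     # all graph-level columns come back in one dict
    label = gdata['label']
```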
@@ -100,12 +100,13 @@ class NodeData(BaseData):
     """ Class of node data which is used for DGLGraph construction. Internal use only. """

     def __init__(self, node_id, data, type=None, graph_id=None):
-        self.id = np.array(node_id, dtype=np.int64)
+        self.id = np.array(node_id)
         self.data = data
         self.type = type if type is not None else '_V'
-        self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
-            len(node_id), 0)
-        _validate_data_length({**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
+        self.graph_id = np.array(
+            graph_id) if graph_id is not None else np.full(len(node_id), 0)
+        _validate_data_length(
+            {**{'id': self.id, 'graph_id': self.graph_id}, **self.data})

     @staticmethod
     def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
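Dropping the forced `dtype=np.int64` cast is what allows IDs to stay as strings. A hedged sketch of how raw, possibly duplicated string IDs can then be remapped to the consecutive integers a DGLGraph needs; DGL's actual `DGLGraphConstructor` may differ in detail:

```python
import numpy as np

raw_ids = np.array(['id_7', 'id_3', 'id_7', 'id_1'])     # duplicated, non-numeric
# np.unique returns the canonical (sorted, deduplicated) ID set, and
# return_inverse gives each raw ID's integer position within it.
unique_ids, inverse = np.unique(raw_ids, return_inverse=True)
id_map = {raw: i for i, raw in enumerate(unique_ids)}
src = np.array([id_map[s] for s in ['id_3', 'id_1']])    # remapped edge endpoints
```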
@@ -145,13 +146,14 @@ class EdgeData(BaseData):
     """ Class of edge data which is used for DGLGraph construction. Internal use only. """

     def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
-        self.src = np.array(src_id, dtype=np.int64)
-        self.dst = np.array(dst_id, dtype=np.int64)
+        self.src = np.array(src_id)
+        self.dst = np.array(dst_id)
         self.data = data
         self.type = type if type is not None else ('_V', '_E', '_V')
-        self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
-            len(src_id), 0)
-        _validate_data_length({**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
+        self.graph_id = np.array(
+            graph_id) if graph_id is not None else np.full(len(src_id), 0)
+        _validate_data_length(
+            {**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})

     @staticmethod
     def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
@@ -195,7 +197,7 @@ class GraphData(BaseData):
     """ Class of graph data which is used for DGLGraph construction. Internal use only. """

     def __init__(self, graph_id, data):
-        self.graph_id = np.array(graph_id, dtype=np.int64)
+        self.graph_id = np.array(graph_id)
         self.data = data
         _validate_data_length({**{'graph_id': self.graph_id}, **self.data})
@@ -269,7 +271,7 @@ class DGLGraphConstructor:

 class DefaultDataParser:
-    """ Default data parser for DGLCSVDataset. It
+    """ Default data parser for CSVDataset. It
     1. ignores any columns which does not have a header.
     2. tries to convert to list of numeric values(generated by
        np.array().tolist()) if cell data is a str separated by ','.
...
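As a hedged, simplified sketch of the comma-splitting behavior the docstring describes (the real parser handles more cases and dtypes):

```python
import numpy as np

def parse_cell(cell):
    # A cell stored as "0.1,0.2,0.3" (e.g. the elements of np.array(...).tolist()
    # joined with commas) becomes a numeric vector; other cells pass through.
    if isinstance(cell, str) and ',' in cell:
        return np.array([float(v) for v in cell.split(',')])
    return cell

print(parse_cell('0.1,0.2,0.3'))   # [0.1 0.2 0.3]
```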
@@ -220,7 +220,7 @@ def test_extract_archive():

 def _test_construct_graphs_homo():
     from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
-    # node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
+    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
     num_nodes = 100
     num_edges = 1000
     num_dims = 3
@@ -228,8 +228,12 @@ def _test_construct_graphs_homo():
     node_ids = np.random.choice(
         np.arange(num_nodes*2), size=num_nodes, replace=False)
     assert len(node_ids) == num_nodes
+    # to be non-sorted
     np.random.shuffle(node_ids)
+    # to be duplicated
     node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
+    # to be non-numeric
+    node_ids = ['id_{}'.format(id) for id in node_ids]
     t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
                'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
     _, u_indices = np.unique(node_ids, return_index=True)
@@ -260,7 +264,7 @@ def _test_construct_graphs_homo():

 def _test_construct_graphs_hetero():
     from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
-    # node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
+    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
     num_nodes = 100
     num_edges = 1000
     num_dims = 3
@@ -273,8 +277,12 @@ def _test_construct_graphs_hetero():
         node_ids = np.random.choice(
             np.arange(num_nodes*2), size=num_nodes, replace=False)
         assert len(node_ids) == num_nodes
+        # to be non-sorted
         np.random.shuffle(node_ids)
+        # to be duplicated
         node_ids = np.hstack((node_ids, node_ids[:num_dup_nodes]))
+        # to be non-numeric
+        node_ids = ['id_{}'.format(id) for id in node_ids]
         t_ndata = {'feat': np.random.rand(num_nodes+num_dup_nodes, num_dims),
                    'label': np.random.randint(2, size=num_nodes+num_dup_nodes)}
         _, u_indices = np.unique(node_ids, return_index=True)
@@ -341,13 +349,16 @@ def _test_construct_graphs_multiple():
         egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
     ndata = {'feat': np.random.rand(num_nodes*num_graphs, num_dims),
              'label': np.random.randint(2, size=num_nodes*num_graphs)}
+    ngraph_ids = ['graph_{}'.format(id) for id in ngraph_ids]
     node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
+    egraph_ids = ['graph_{}'.format(id) for id in egraph_ids]
     edata = {'feat': np.random.rand(
         num_edges*num_graphs, num_dims), 'label': np.random.randint(2, size=num_edges*num_graphs)}
     edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
     gdata = {'feat': np.random.rand(num_graphs, num_dims),
              'label': np.random.randint(2, size=num_graphs)}
-    graph_data = GraphData(np.arange(num_graphs), gdata)
+    graph_ids = ['graph_{}'.format(id) for id in np.arange(num_graphs)]
+    graph_data = GraphData(graph_ids, gdata)
     graphs, data_dict = DGLGraphConstructor.construct_graphs(
         node_data, edge_data, graph_data)
     assert len(graphs) == num_graphs
@@ -728,7 +739,7 @@ def _test_load_graph_data_from_csv():
     assert expect_except

-def _test_DGLCSVDataset_single():
+def _test_CSVDataset_single():
     with tempfile.TemporaryDirectory() as test_dir:
         # generate YAML/CSVs
         meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -779,7 +790,7 @@ def _test_DGLCSVDataset_single():
                 # remove original node data file to verify reload from cached files
                 os.remove(nodes_csv_path_0)
                 assert not os.path.exists(nodes_csv_path_0)
-            csv_dataset = data.DGLCSVDataset(
+            csv_dataset = data.CSVDataset(
                 test_dir, force_reload=force_reload)
             assert len(csv_dataset) == 1
             g = csv_dataset[0]
@@ -799,7 +810,7 @@ def _test_DGLCSVDataset_single():
                                   F.asnumpy(g.edges[etype].data['label']))

-def _test_DGLCSVDataset_multiple():
+def _test_CSVDataset_multiple():
     with tempfile.TemporaryDirectory() as test_dir:
         # generate YAML/CSVs
         meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -856,13 +867,13 @@ def _test_DGLCSVDataset_multiple():
         })
         df.to_csv(graph_csv_path, index=False)

-        # load CSVDataset with default node/edge/graph_data_parser
+        # load CSVDataset with default node/edge/gdata_parser
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
-            csv_dataset = data.DGLCSVDataset(
+            csv_dataset = data.CSVDataset(
                test_dir, force_reload=force_reload)
            assert len(csv_dataset) == num_graphs
            assert csv_dataset.has_cache()
@@ -871,9 +882,10 @@ def _test_DGLCSVDataset_multiple():
             assert 'label' in csv_dataset.data
             assert F.array_equal(F.tensor(feat_gdata),
                                  csv_dataset.data['feat'])
-            for i, (g, label) in enumerate(csv_dataset):
+            for i, (g, g_data) in enumerate(csv_dataset):
                 assert not g.is_homogeneous
-                assert F.asnumpy(label) == label_gdata[i]
+                assert F.asnumpy(g_data['label']) == label_gdata[i]
+                assert F.array_equal(g_data['feat'], F.tensor(feat_gdata[i]))
                 for ntype in g.ntypes:
                     assert g.num_nodes(ntype) == num_nodes
                     assert F.array_equal(F.tensor(feat_ndata[i*num_nodes:(i+1)*num_nodes]),
@@ -888,7 +900,7 @@ def _test_DGLCSVDataset_multiple():
                                       F.asnumpy(g.edges[etype].data['label']))

-def _test_DGLCSVDataset_customized_data_parser():
+def _test_CSVDataset_customized_data_parser():
     with tempfile.TemporaryDirectory() as test_dir:
         # generate YAML/CSVs
         meta_yaml_path = os.path.join(test_dir, "meta.yaml")
@@ -947,25 +959,48 @@ def _test_DGLCSVDataset_customized_data_parser():
                 dt += 2
                 data[header] = dt
                 return data

-        # load CSVDataset with customized node/edge/graph_data_parser
-        csv_dataset = data.DGLCSVDataset(
-            test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser())
+        # load CSVDataset with customized node/edge/gdata_parser
+        # specify via dict[ntype/etype, callable]
+        csv_dataset = data.CSVDataset(
+            test_dir, force_reload=True, ndata_parser={'user': CustDataParser()},
+            edata_parser={('user', 'like', 'item'): CustDataParser()},
+            gdata_parser=CustDataParser())
         assert len(csv_dataset) == num_graphs
         assert len(csv_dataset.data) == 1
         assert 'label' in csv_dataset.data
-        for i, (g, label) in enumerate(csv_dataset):
+        for i, (g, g_data) in enumerate(csv_dataset):
             assert not g.is_homogeneous
-            assert F.asnumpy(label) == label_gdata[i] + 2
+            assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
             for ntype in g.ntypes:
                 assert g.num_nodes(ntype) == num_nodes
                 offset = 2 if ntype == 'user' else 0
                 assert np.array_equal(label_ndata[i*num_nodes:(i+1)*num_nodes]+offset,
                                       F.asnumpy(g.nodes[ntype].data['label']))
             for etype in g.etypes:
                 assert g.num_edges(etype) == num_edges
                 offset = 2 if etype == 'like' else 0
                 assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges]+offset,
                                       F.asnumpy(g.edges[etype].data['label']))
+        # specify via callable
+        csv_dataset = data.CSVDataset(
+            test_dir, force_reload=True, ndata_parser=CustDataParser(),
+            edata_parser=CustDataParser(), gdata_parser=CustDataParser())
+        assert len(csv_dataset) == num_graphs
+        assert len(csv_dataset.data) == 1
+        assert 'label' in csv_dataset.data
+        for i, (g, g_data) in enumerate(csv_dataset):
+            assert not g.is_homogeneous
+            assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
+            for ntype in g.ntypes:
+                assert g.num_nodes(ntype) == num_nodes
+                offset = 2
+                assert np.array_equal(label_ndata[i*num_nodes:(i+1)*num_nodes]+offset,
+                                      F.asnumpy(g.nodes[ntype].data['label']))
+            for etype in g.etypes:
+                assert g.num_edges(etype) == num_edges
+                offset = 2
+                assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges]+offset,
+                                      F.asnumpy(g.edges[etype].data['label']))

 def _test_NodeEdgeGraphData():
@@ -974,8 +1009,7 @@ def _test_NodeEdgeGraphData():
     num_nodes = 100
     node_ids = np.arange(num_nodes, dtype=np.float)
     ndata = NodeData(node_ids, {})
-    assert ndata.id.dtype == np.int64
-    assert np.array_equal(ndata.id, node_ids.astype(np.int64))
+    assert np.array_equal(ndata.id, node_ids)
     assert len(ndata.data) == 0
     assert ndata.type == '_V'
     assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
@@ -1017,8 +1051,6 @@ def _test_NodeEdgeGraphData():
     graph_ids = np.arange(num_edges)
     edata = EdgeData(src_ids, dst_ids, data,
                      type=etype, graph_id=graph_ids)
-    assert edata.src.dtype == np.int64
-    assert edata.dst.dtype == np.int64
     assert np.array_equal(edata.src, src_ids)
     assert np.array_equal(edata.dst, dst_ids)
     assert edata.type == etype
@@ -1046,7 +1078,6 @@ def _test_NodeEdgeGraphData():
     graph_ids = np.arange(num_graphs).astype(np.float)
     data = {'feat': np.random.rand(num_graphs, 3)}
     gdata = GraphData(graph_ids, data)
-    assert gdata.graph_id.dtype == np.int64
     assert np.array_equal(gdata.graph_id, graph_ids)
     assert len(gdata.data) == len(data)
     for k, v in data.items():
@@ -1065,9 +1096,9 @@ def test_csvdataset():
     _test_load_node_data_from_csv()
     _test_load_edge_data_from_csv()
     _test_load_graph_data_from_csv()
-    _test_DGLCSVDataset_single()
-    _test_DGLCSVDataset_multiple()
-    _test_DGLCSVDataset_customized_data_parser()
+    _test_CSVDataset_single()
+    _test_CSVDataset_multiple()
+    _test_CSVDataset_customized_data_parser()

 @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
 def test_add_nodepred_split():
@@ -1181,7 +1212,7 @@ def test_as_nodepred_csvdataset():
         })
         df.to_csv(edges_csv_path, index=False)

-        ds = data.DGLCSVDataset(test_dir, force_reload=True)
+        ds = data.CSVDataset(test_dir, force_reload=True)
         assert 'feat' in ds[0].ndata
         assert 'label' in ds[0].ndata
         assert 'train_mask' not in ds[0].ndata
...