Unverified commit 9ec9df57 authored by Rhett Ying, committed by GitHub

[Fix] check and load dependencies when needed (#3655)

* [Fix] check and load dependencies when needed

* refine rdflib import
parent 77f4287a
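
The fix moves heavy optional imports (rdflib, pydantic, pandas, yaml) from module import time to the call sites that first need them. A minimal sketch of the pattern, assuming rdflib as the optional dependency; the function name and the explicit ImportError message below are illustrative and not part of this commit:

def load_rdf_graph(path):
    # Importing inside the function lets `import dgl.data` succeed even
    # when rdflib is not installed; the dependency is only required here.
    try:
        import rdflib as rdf
    except ImportError:
        raise ImportError('rdflib is required for RDF datasets; '
                          'install it with `pip install rdflib`.')
    g = rdf.Graph()
    g.parse(path)
    return g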
@@ -29,7 +29,7 @@ from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
 from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
 from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
 from .fakenews import FakeNewsDataset
+from .csv_dataset import DGLCSVDataset

 def register_data_args(parser):
     parser.add_argument(
...
 import os
-import yaml
-from yaml.loader import SafeLoader
-import pandas as pd
 import numpy as np
-from typing import List, Optional
-import pydantic as dt
 from .dgl_dataset import DGLDataset
-from ..convert import heterograph as dgl_heterograph
-from .. import backend as F
 from .utils import save_graphs, load_graphs
-from ..base import dgl_warning, DGLError
+from ..base import DGLError
-import abc
-import ast
-from typing import Callable
...
 class DGLCSVDataset(DGLDataset):
@@ -337,6 +47,7 @@ class DGLCSVDataset(DGLDataset):
     META_YAML_NAME = 'meta.yaml'

     def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
+        from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
         self.graphs = None
         self.data = None
         self.node_data_parser = {} if node_data_parser is None else node_data_parser
@@ -352,9 +63,11 @@ class DGLCSVDataset(DGLDataset):
         super().__init__(ds_name, raw_dir=os.path.dirname(
             meta_yaml_path), force_reload=force_reload, verbose=verbose)

     def process(self):
         """Parse node/edge data from CSV files and construct DGLGraphs."""
+        from .csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor
         meta_yaml = self.meta_yaml
         base_dir = self.raw_dir
         node_data = []
...
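
For reference, a minimal end-to-end usage sketch of DGLCSVDataset as changed above (illustrative only; the dataset name, file names, and field values are made up, and meta.yaml follows the MetaYaml schema from csv_dataset_base below):

import os
import tempfile

import numpy as np
import pandas as pd
import yaml

import dgl.data as data

with tempfile.TemporaryDirectory() as test_dir:
    num_nodes, num_edges = 5, 8
    # One node CSV and one edge CSV, using the default field names.
    pd.DataFrame({'node_id': np.arange(num_nodes)}).to_csv(
        os.path.join(test_dir, 'nodes.csv'), index=False)
    pd.DataFrame({'src_id': np.random.randint(num_nodes, size=num_edges),
                  'dst_id': np.random.randint(num_nodes, size=num_edges)}).to_csv(
        os.path.join(test_dir, 'edges.csv'), index=False)
    meta = {'version': '1.0.0',
            'dataset_name': 'toy',
            'node_data': [{'file_name': 'nodes.csv'}],
            'edge_data': [{'file_name': 'edges.csv'}]}
    with open(os.path.join(test_dir, 'meta.yaml'), 'w') as f:
        yaml.dump(meta, f, sort_keys=False)

    ds = data.DGLCSVDataset(test_dir)  # exported by this commit's __init__ change
    print(ds[0])  # a homogeneous DGLGraph with 5 nodes and 8 edges

The new module csv_dataset_base.py holds everything moved out of csv_dataset.py: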
+import os
+import numpy as np
+from typing import List, Optional, Callable
+from .. import backend as F
+from ..convert import heterograph as dgl_heterograph
+from ..base import dgl_warning, DGLError
+import ast
+import pydantic as dt
+import pandas as pd
+import yaml
+
+
+class MetaNode(dt.BaseModel):
+    """ Class of node_data in YAML. Internal use only. """
+    file_name: str
+    ntype: Optional[str] = '_V'
+    graph_id_field: Optional[str] = 'graph_id'
+    node_id_field: Optional[str] = 'node_id'
+
+
+class MetaEdge(dt.BaseModel):
+    """ Class of edge_data in YAML. Internal use only. """
+    file_name: str
+    etype: Optional[List[str]] = ['_V', '_E', '_V']
+    graph_id_field: Optional[str] = 'graph_id'
+    src_id_field: Optional[str] = 'src_id'
+    dst_id_field: Optional[str] = 'dst_id'
+
+
+class MetaGraph(dt.BaseModel):
+    """ Class of graph_data in YAML. Internal use only. """
+    file_name: str
+    graph_id_field: Optional[str] = 'graph_id'
+
+
+class MetaYaml(dt.BaseModel):
+    """ Class of YAML. Internal use only. """
+    version: Optional[str] = '1.0.0'
+    dataset_name: str
+    separator: Optional[str] = ','
+    node_data: List[MetaNode]
+    edge_data: List[MetaEdge]
+    graph_data: Optional[MetaGraph] = None
+
+
+def load_yaml_with_sanity_check(yaml_file):
+    """ Load yaml and do sanity check. Internal use only. """
+    with open(yaml_file) as f:
+        yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
+        try:
+            meta_yaml = MetaYaml(**yaml_data)
+        except dt.ValidationError as e:
+            print("Details of pydantic.ValidationError:\n{}".format(e.json()))
+            raise DGLError(
+                "Validation Error for YAML fields. Details are shown above.")
+        if meta_yaml.version != '1.0.0':
+            raise DGLError("Invalid CSVDataset version {}. Supported versions: '1.0.0'".format(
+                meta_yaml.version))
+        ntypes = [meta.ntype for meta in meta_yaml.node_data]
+        if len(ntypes) > len(set(ntypes)):
+            raise DGLError(
+                "Each node CSV file must have a unique node type name, but found duplicate node types: {}.".format(ntypes))
+        etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data]
+        if len(etypes) > len(set(etypes)):
+            raise DGLError(
+                "Each edge CSV file must have a unique edge type name, but found duplicate edge types: {}.".format(etypes))
+        return meta_yaml
+
+
+def _validate_data_length(data_dict):
+    len_dict = {k: len(v) for k, v in data_dict.items()}
+    lst = list(len_dict.values())
+    res = lst.count(lst[0]) == len(lst)
+    if not res:
+        raise DGLError(
+            "All data are required to have the same length, but some do not. Length of data={}".format(str(len_dict)))
+
+
+class BaseData:
+    """ Base data class inherited by Node/Edge/GraphData. Internal use only. """
+    @staticmethod
+    def read_csv(file_name, base_dir, separator):
+        csv_path = file_name
+        if base_dir is not None:
+            csv_path = os.path.join(base_dir, csv_path)
+        return pd.read_csv(csv_path, sep=separator)
+
+    @staticmethod
+    def pop_from_dataframe(df: pd.DataFrame, item: str):
+        ret = None
+        try:
+            ret = df.pop(item).to_numpy().squeeze()
+        except KeyError:
+            pass
+        return ret
+
+
+class NodeData(BaseData):
+    """ Class of node data which is used for DGLGraph construction. Internal use only. """
+    def __init__(self, node_id, data, type=None, graph_id=None):
+        self.id = np.array(node_id, dtype=np.int64)
+        self.data = data
+        self.type = type if type is not None else '_V'
+        self.graph_id = np.array(
+            graph_id, dtype=np.int64) if graph_id is not None else np.full(len(node_id), 0)
+        _validate_data_length(
+            {**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
+
+    @staticmethod
+    def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
+        df = BaseData.read_csv(meta.file_name, base_dir, separator)
+        node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field)
+        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
+        if node_ids is None:
+            raise DGLError("Missing node id field [{}] in file [{}].".format(
+                meta.node_id_field, meta.file_name))
+        ntype = meta.ntype
+        ndata = data_parser(df)
+        return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids)
+
+    @staticmethod
+    def to_dict(node_data: List['NodeData']) -> dict:
+        # node_ids could be arbitrary numeric values, namely non-sorted,
+        # duplicated, and not labeled from 0 to num_nodes-1.
+        node_dict = {}
+        for n_data in node_data:
+            graph_ids = np.unique(n_data.graph_id)
+            for graph_id in graph_ids:
+                idx = n_data.graph_id == graph_id
+                ids = n_data.id[idx]
+                u_ids, u_indices = np.unique(ids, return_index=True)
+                if len(ids) > len(u_ids):
+                    dgl_warning(
+                        "There exist duplicated ids and only the first ones are kept.")
+                if graph_id not in node_dict:
+                    node_dict[graph_id] = {}
+                node_dict[graph_id][n_data.type] = {
+                    'mapping': {index: i for i, index in enumerate(ids[u_indices])},
+                    'data': {k: F.tensor(v[idx][u_indices])
+                             for k, v in n_data.data.items()}}
+        return node_dict
+
+
+class EdgeData(BaseData):
+    """ Class of edge data which is used for DGLGraph construction. Internal use only. """
+    def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
+        self.src = np.array(src_id, dtype=np.int64)
+        self.dst = np.array(dst_id, dtype=np.int64)
+        self.data = data
+        self.type = type if type is not None else ('_V', '_E', '_V')
+        self.graph_id = np.array(
+            graph_id, dtype=np.int64) if graph_id is not None else np.full(len(src_id), 0)
+        _validate_data_length(
+            {**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
+
+    @staticmethod
+    def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
+        df = BaseData.read_csv(meta.file_name, base_dir, separator)
+        src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field)
+        if src_ids is None:
+            raise DGLError("Missing src id field [{}] in file [{}].".format(
+                meta.src_id_field, meta.file_name))
+        dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field)
+        if dst_ids is None:
+            raise DGLError("Missing dst id field [{}] in file [{}].".format(
+                meta.dst_id_field, meta.file_name))
+        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
+        etype = tuple(meta.etype)
+        edata = data_parser(df)
+        return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids)
+
+    @staticmethod
+    def to_dict(edge_data: List['EdgeData'], node_dict: dict) -> dict:
+        edge_dict = {}
+        for e_data in edge_data:
+            (src_type, e_type, dst_type) = e_data.type
+            graph_ids = np.unique(e_data.graph_id)
+            for graph_id in graph_ids:
+                if graph_id in edge_dict and e_data.type in edge_dict[graph_id]:
+                    raise DGLError(f"Duplicate edge type [{e_data.type}] for the same graph [{graph_id}]; please place data of the same edge type for the same graph into a single EdgeData.")
+                idx = e_data.graph_id == graph_id
+                src_mapping = node_dict[graph_id][src_type]['mapping']
+                dst_mapping = node_dict[graph_id][dst_type]['mapping']
+                src_ids = [src_mapping[index] for index in e_data.src[idx]]
+                dst_ids = [dst_mapping[index] for index in e_data.dst[idx]]
+                if graph_id not in edge_dict:
+                    edge_dict[graph_id] = {}
+                edge_dict[graph_id][e_data.type] = {
+                    'edges': (F.tensor(src_ids), F.tensor(dst_ids)),
+                    'data': {k: F.tensor(v[idx])
+                             for k, v in e_data.data.items()}}
+        return edge_dict
+
+
+class GraphData(BaseData):
+    """ Class of graph data which is used for DGLGraph construction. Internal use only. """
+    def __init__(self, graph_id, data):
+        self.graph_id = np.array(graph_id, dtype=np.int64)
+        self.data = data
+        _validate_data_length({**{'graph_id': self.graph_id}, **self.data})
+
+    @staticmethod
+    def load_from_csv(meta: MetaGraph, data_parser: Callable, base_dir=None, separator=','):
+        df = BaseData.read_csv(meta.file_name, base_dir, separator)
+        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
+        if graph_ids is None:
+            raise DGLError("Missing graph id field [{}] in file [{}].".format(
+                meta.graph_id_field, meta.file_name))
+        gdata = data_parser(df)
+        return GraphData(graph_ids, gdata)
+
+    @staticmethod
+    def to_dict(graph_data: 'GraphData', graphs_dict: dict) -> tuple:
+        missing_ids = np.setdiff1d(
+            np.array(list(graphs_dict.keys())), graph_data.graph_id)
+        if len(missing_ids) > 0:
+            raise DGLError(
+                "Found the following graph ids in node/edge CSVs but not in the graph CSV: {}.".format(missing_ids))
+        graph_ids = graph_data.graph_id
+        graphs = []
+        for graph_id in graph_ids:
+            if graph_id not in graphs_dict:
+                graphs_dict[graph_id] = dgl_heterograph(
+                    {('_V', '_E', '_V'): ([], [])})
+        for graph_id in graph_ids:
+            graphs.append(graphs_dict[graph_id])
+        data = {k: F.tensor(v) for k, v in graph_data.data.items()}
+        return graphs, data
+
+
+class DGLGraphConstructor:
+    """ Class for constructing DGLGraph from Node/Edge/Graph data. Internal use only. """
+    @staticmethod
+    def construct_graphs(node_data, edge_data, graph_data=None):
+        if not isinstance(node_data, list):
+            node_data = [node_data]
+        if not isinstance(edge_data, list):
+            edge_data = [edge_data]
+        node_dict = NodeData.to_dict(node_data)
+        edge_dict = EdgeData.to_dict(edge_data, node_dict)
+        graph_dict = DGLGraphConstructor._construct_graphs(
+            node_dict, edge_dict)
+        if graph_data is None:
+            graph_data = GraphData(np.full(1, 0), {})
+        graphs, data = GraphData.to_dict(
+            graph_data, graph_dict)
+        return graphs, data
+
+    @staticmethod
+    def _construct_graphs(node_dict, edge_dict):
+        graph_dict = {}
+        for graph_id in node_dict:
+            if graph_id not in edge_dict:
+                # No edges for this graph: fall back to an empty default
+                # edge type so heterograph construction still works.
+                edge_dict[graph_id] = {
+                    ('_V', '_E', '_V'): {'edges': ([], []), 'data': {}}}
+            graph = dgl_heterograph(
+                {etype: edata['edges']
+                 for etype, edata in edge_dict[graph_id].items()},
+                num_nodes_dict={ntype: len(ndata['mapping'])
+                                for ntype, ndata in node_dict[graph_id].items()})
+
+            def assign_data(type, src_data, dst_data):
+                for key, value in src_data.items():
+                    dst_data[type].data[key] = value
+            for type, data in node_dict[graph_id].items():
+                assign_data(type, data['data'], graph.nodes)
+            for type, data in edge_dict[graph_id].items():
+                assign_data(type, data['data'], graph.edges)
+            graph_dict[graph_id] = graph
+        return graph_dict
+
+
+class DefaultDataParser:
+    """ Default data parser for DGLCSVDataset. It
+    1. ignores any column which does not have a header;
+    2. tries to convert a cell into a list of numeric values (as
+       generated by np.array().tolist()) if the cell data is a str
+       separated by ',';
+    3. reads the data and infers the data type directly, otherwise.
+    """
+    def __call__(self, df: pd.DataFrame):
+        data = {}
+        for header in df:
+            if 'Unnamed' in header:
+                dgl_warning("Unnamed column is found. Ignored...")
+                continue
+            dt = df[header].to_numpy().squeeze()
+            if len(dt) > 0 and isinstance(dt[0], str):
+                # probably consists of lists of numeric values
+                dt = np.array([ast.literal_eval(row) for row in dt])
+            data[header] = dt
+        return data
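
A short sketch of how the pieces above fit together without going through meta.yaml (illustrative only; the file names and values are made up, and the module path dgl.data.csv_dataset_base is the one introduced by this commit). Note the node ids are deliberately unsorted and non-contiguous, which NodeData.to_dict() handles via its id mapping:

import os
import tempfile

import pandas as pd

from dgl.data.csv_dataset_base import (
    DefaultDataParser, DGLGraphConstructor, EdgeData, MetaEdge, MetaNode, NodeData)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Node ids may be unsorted and non-contiguous; to_dict() remaps them.
    pd.DataFrame({'node_id': [10, 3, 7],
                  'label': [0, 1, 0]}).to_csv(
        os.path.join(tmp_dir, 'nodes.csv'), index=False)
    pd.DataFrame({'src_id': [10, 7],
                  'dst_id': [3, 10]}).to_csv(
        os.path.join(tmp_dir, 'edges.csv'), index=False)

    node_data = NodeData.load_from_csv(
        MetaNode(file_name='nodes.csv'), DefaultDataParser(), base_dir=tmp_dir)
    edge_data = EdgeData.load_from_csv(
        MetaEdge(file_name='edges.csv'), DefaultDataParser(), base_dir=tmp_dir)
    graphs, data = DGLGraphConstructor.construct_graphs(node_data, edge_data)
    print(graphs[0])  # a homogeneous DGLGraph with 3 nodes and 2 edges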
@@ -8,10 +8,6 @@ from collections import OrderedDict
 import itertools
 import abc
 import re
-try:
-    import rdflib as rdf
-except ImportError:
-    pass

 import networkx as nx
 import numpy as np
@@ -115,7 +111,6 @@ class RDFGraphDataset(DGLBuiltinDataset):
                  raw_dir=None,
                  force_reload=False,
                  verbose=True):
-        import rdflib as rdf
        self._insert_reverse = insert_reverse
        self._print_every = print_every
        self._predict_category = predict_category
@@ -141,6 +136,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
        -------
        Loaded rdf data
        """
+        import rdflib as rdf
         raw_rdf_graphs = []
         for _, filename in enumerate(os.listdir(root_path)):
             fmt = None
@@ -674,6 +670,7 @@ class AIFBDataset(RDFGraphDataset):
         return super(AIFBDataset, self).__len__()

     def parse_entity(self, term):
+        import rdflib as rdf
         if isinstance(term, rdf.Literal):
             return Entity(e_id=str(term), cls="_Literal")
         if isinstance(term, rdf.BNode):
@@ -855,6 +852,7 @@ class MUTAGDataset(RDFGraphDataset):
         return super(MUTAGDataset, self).__len__()

     def parse_entity(self, term):
+        import rdflib as rdf
         if isinstance(term, rdf.Literal):
             return Entity(e_id=str(term), cls="_Literal")
         elif isinstance(term, rdf.BNode):
@@ -1052,6 +1050,7 @@ class BGSDataset(RDFGraphDataset):
         return super(BGSDataset, self).__len__()

     def parse_entity(self, term):
+        import rdflib as rdf
         if isinstance(term, rdf.Literal):
             return None
         elif isinstance(term, rdf.BNode):
@@ -1247,6 +1246,7 @@ class AMDataset(RDFGraphDataset):
         return super(AMDataset, self).__len__()

     def parse_entity(self, term):
+        import rdflib as rdf
         if isinstance(term, rdf.Literal):
             return None
         elif isinstance(term, rdf.BNode):
...
@@ -8,7 +8,6 @@ import pandas as pd
 import yaml
 import pytest
 import dgl.data as data
-import dgl.data.csv_dataset as csv_ds
 from dgl import DGLError
@@ -164,6 +163,7 @@ def test_extract_archive():

 def _test_construct_graphs_homo():
+    from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
     # node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
     num_nodes = 100
     num_edges = 1000
@@ -179,13 +179,13 @@ def _test_construct_graphs_homo():
     _, u_indices = np.unique(node_ids, return_index=True)
     ndata = {'feat': t_ndata['feat'][u_indices],
              'label': t_ndata['label'][u_indices]}
-    node_data = csv_ds.NodeData(node_ids, t_ndata)
+    node_data = NodeData(node_ids, t_ndata)
     src_ids = np.random.choice(node_ids, size=num_edges)
     dst_ids = np.random.choice(node_ids, size=num_edges)
     edata = {'feat': np.random.rand(
         num_edges, num_dims), 'label': np.random.randint(2, size=num_edges)}
-    edge_data = csv_ds.EdgeData(src_ids, dst_ids, edata)
-    graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
+    edge_data = EdgeData(src_ids, dst_ids, edata)
+    graphs, data_dict = DGLGraphConstructor.construct_graphs(
         node_data, edge_data)
     assert len(graphs) == 1
     assert len(data_dict) == 0
@@ -203,6 +203,7 @@ def _test_construct_graphs_homo():

 def _test_construct_graphs_hetero():
+    from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
     # node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
     num_nodes = 100
     num_edges = 1000
@@ -223,7 +224,7 @@ def _test_construct_graphs_hetero():
         _, u_indices = np.unique(node_ids, return_index=True)
         ndata = {'feat': t_ndata['feat'][u_indices],
                  'label': t_ndata['label'][u_indices]}
-        node_data.append(csv_ds.NodeData(node_ids, t_ndata, type=ntype))
+        node_data.append(NodeData(node_ids, t_ndata, type=ntype))
         node_ids_dict[ntype] = node_ids
         ndata_dict[ntype] = ndata
     etypes = [('user', 'follow', 'user'), ('user', 'like', 'item')]
@@ -234,10 +235,10 @@ def _test_construct_graphs_hetero():
         dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
         edata = {'feat': np.random.rand(
             num_edges, num_dims), 'label': np.random.randint(2, size=num_edges)}
-        edge_data.append(csv_ds.EdgeData(src_ids, dst_ids, edata,
-                                         type=(src_type, e_type, dst_type)))
+        edge_data.append(EdgeData(src_ids, dst_ids, edata,
+                                  type=(src_type, e_type, dst_type)))
         edata_dict[(src_type, e_type, dst_type)] = edata
-    graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
+    graphs, data_dict = DGLGraphConstructor.construct_graphs(
         node_data, edge_data)
     assert len(graphs) == 1
     assert len(data_dict) == 0
@@ -259,6 +260,7 @@ def _test_construct_graphs_hetero():

 def _test_construct_graphs_multiple():
+    from dgl.data.csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor
     num_nodes = 100
     num_edges = 1000
     num_graphs = 10
@@ -283,14 +285,14 @@ def _test_construct_graphs_multiple():
         egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
     ndata = {'feat': np.random.rand(num_nodes*num_graphs, num_dims),
              'label': np.random.randint(2, size=num_nodes*num_graphs)}
-    node_data = csv_ds.NodeData(node_ids, ndata, graph_id=ngraph_ids)
+    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
     edata = {'feat': np.random.rand(
         num_edges*num_graphs, num_dims), 'label': np.random.randint(2, size=num_edges*num_graphs)}
-    edge_data = csv_ds.EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
+    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
     gdata = {'feat': np.random.rand(num_graphs, num_dims),
              'label': np.random.randint(2, size=num_graphs)}
-    graph_data = csv_ds.GraphData(np.arange(num_graphs), gdata)
-    graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
+    graph_data = GraphData(np.arange(num_graphs), gdata)
+    graphs, data_dict = DGLGraphConstructor.construct_graphs(
         node_data, edge_data, graph_data)
     assert len(graphs) == num_graphs
     assert len(data_dict) == len(gdata)
@@ -313,10 +315,10 @@ def _test_construct_graphs_multiple():
         assert_data(edata, g.edata, num_edges)

     # Graph IDs found in node/edge CSV but not in graph CSV
-    graph_data = csv_ds.GraphData(np.arange(num_graphs-2), {})
+    graph_data = GraphData(np.arange(num_graphs-2), {})
     expect_except = False
     try:
-        _, _ = csv_ds.DGLGraphConstructor.construct_graphs(
+        _, _ = DGLGraphConstructor.construct_graphs(
             node_data, edge_data, graph_data)
     except:
         expect_except = True
@@ -324,6 +326,7 @@ def _test_construct_graphs_multiple():

 def _test_DefaultDataParser():
+    from dgl.data.csv_dataset_base import DefaultDataParser
     # common csv
     with tempfile.TemporaryDirectory() as test_dir:
         csv_path = os.path.join(test_dir, "nodes.csv")
@@ -337,7 +340,7 @@ def _test_DefaultDataParser():
             'feat': [line.tolist() for line in feat],
         })
         df.to_csv(csv_path, index=False)
-        dp = csv_ds.DefaultDataParser()
+        dp = DefaultDataParser()
         df = pd.read_csv(csv_path)
         dt = dp(df)
         assert np.array_equal(node_id, dt['node_id'])
@@ -349,7 +352,7 @@ def _test_DefaultDataParser():
         df = pd.DataFrame({'label': ['a', 'b', 'c'],
                            })
         df.to_csv(csv_path, index=False)
-        dp = csv_ds.DefaultDataParser()
+        dp = DefaultDataParser()
         df = pd.read_csv(csv_path)
         expect_except = False
         try:
@@ -363,13 +366,14 @@ def _test_DefaultDataParser():
         df = pd.DataFrame({'label': [1, 2, 3],
                            })
         df.to_csv(csv_path)
-        dp = csv_ds.DefaultDataParser()
+        dp = DefaultDataParser()
         df = pd.read_csv(csv_path)
         dt = dp(df)
         assert len(dt) == 1

 def _test_load_yaml_with_sanity_check():
+    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check
     with tempfile.TemporaryDirectory() as test_dir:
         yaml_path = os.path.join(test_dir, 'meta.yaml')
         # workable but meaningless usually
@@ -377,7 +381,7 @@ def _test_load_yaml_with_sanity_check():
                      'node_data': [], 'edge_data': []}
         with open(yaml_path, 'w') as f:
             yaml.dump(yaml_data, f, sort_keys=False)
-        meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+        meta = load_yaml_with_sanity_check(yaml_path)
         assert meta.version == '1.0.0'
         assert meta.dataset_name == 'default'
         assert meta.separator == ','
@@ -390,7 +394,7 @@ def _test_load_yaml_with_sanity_check():
         }
         with open(yaml_path, 'w') as f:
             yaml.dump(yaml_data, f, sort_keys=False)
-        meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+        meta = load_yaml_with_sanity_check(yaml_path)
         for ndata in meta.node_data:
             assert ndata.file_name == 'nodes.csv'
             assert ndata.ntype == '_V'
@@ -411,7 +415,7 @@ def _test_load_yaml_with_sanity_check():
         }
         with open(yaml_path, 'w') as f:
             yaml.dump(yaml_data, f, sort_keys=False)
-        meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+        meta = load_yaml_with_sanity_check(yaml_path)
         assert len(meta.node_data) == 1
         ndata = meta.node_data[0]
         assert ndata.ntype == 'user'
@@ -436,7 +440,7 @@ def _test_load_yaml_with_sanity_check():
                yaml.dump(ydata, f, sort_keys=False)
            expect_except = False
            try:
-               meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+               meta = load_yaml_with_sanity_check(yaml_path)
            except:
                expect_except = True
            assert expect_except
@@ -448,7 +452,7 @@ def _test_load_yaml_with_sanity_check():
             yaml.dump(yaml_data, f, sort_keys=False)
         expect_except = False
         try:
-            meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+            meta = load_yaml_with_sanity_check(yaml_path)
         except DGLError:
             expect_except = True
         assert expect_except
@@ -460,7 +464,7 @@ def _test_load_yaml_with_sanity_check():
             yaml.dump(yaml_data, f, sort_keys=False)
         expect_except = False
         try:
-            meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
+            meta = load_yaml_with_sanity_check(yaml_path)
         except DGLError:
             expect_except = True
         assert expect_except

 def _test_load_node_data_from_csv():
+    from dgl.data.csv_dataset_base import MetaNode, NodeData, DefaultDataParser
     with tempfile.TemporaryDirectory() as test_dir:
         num_nodes = 100
         # minimum
         df = pd.DataFrame({'node_id': np.arange(num_nodes)})
         csv_path = os.path.join(test_dir, 'nodes.csv')
         df.to_csv(csv_path, index=False)
-        meta_node = csv_ds.MetaNode(file_name=csv_path)
-        node_data = csv_ds.NodeData.load_from_csv(
-            meta_node, csv_ds.DefaultDataParser())
+        meta_node = MetaNode(file_name=csv_path)
+        node_data = NodeData.load_from_csv(
+            meta_node, DefaultDataParser())
         assert np.array_equal(df['node_id'], node_data.id)
         assert len(node_data.data) == 0
@@ -496,9 +501,9 @@ def _test_load_node_data_from_csv():
                            'label': np.random.randint(3, size=num_nodes)})
         csv_path = os.path.join(test_dir, 'nodes.csv')
         df.to_csv(csv_path, index=False)
-        meta_node = csv_ds.MetaNode(file_name=csv_path)
-        node_data = csv_ds.NodeData.load_from_csv(
-            meta_node, csv_ds.DefaultDataParser())
+        meta_node = MetaNode(file_name=csv_path)
+        node_data = NodeData.load_from_csv(
+            meta_node, DefaultDataParser())
         assert np.array_equal(df['node_id'], node_data.id)
         assert len(node_data.data) == 1
         assert np.array_equal(df['label'], node_data.data['label'])
@@ -510,9 +515,9 @@ def _test_load_node_data_from_csv():
             3, size=num_nodes), 'graph_id': np.full(num_nodes, 1)})
         csv_path = os.path.join(test_dir, 'nodes.csv')
         df.to_csv(csv_path, index=False)
-        meta_node = csv_ds.MetaNode(file_name=csv_path)
-        node_data = csv_ds.NodeData.load_from_csv(
-            meta_node, csv_ds.DefaultDataParser())
+        meta_node = MetaNode(file_name=csv_path)
+        node_data = NodeData.load_from_csv(
+            meta_node, DefaultDataParser())
         assert np.array_equal(df['node_id'], node_data.id)
         assert len(node_data.data) == 1
         assert np.array_equal(df['label'], node_data.data['label'])
@@ -523,17 +528,18 @@ def _test_load_node_data_from_csv():
         df = pd.DataFrame({'label': np.random.randint(3, size=num_nodes)})
         csv_path = os.path.join(test_dir, 'nodes.csv')
         df.to_csv(csv_path, index=False)
-        meta_node = csv_ds.MetaNode(file_name=csv_path)
+        meta_node = MetaNode(file_name=csv_path)
         expect_except = False
         try:
-            csv_ds.NodeData.load_from_csv(
-                meta_node, csv_ds.DefaultDataParser())
+            NodeData.load_from_csv(
+                meta_node, DefaultDataParser())
         except:
             expect_except = True
         assert expect_except

 def _test_load_edge_data_from_csv():
+    from dgl.data.csv_dataset_base import MetaEdge, EdgeData, DefaultDataParser
     with tempfile.TemporaryDirectory() as test_dir:
         num_nodes = 100
         num_edges = 1000
@@ -543,9 +549,9 @@ def _test_load_edge_data_from_csv():
         })
         csv_path = os.path.join(test_dir, 'edges.csv')
         df.to_csv(csv_path, index=False)
-        meta_edge = csv_ds.MetaEdge(file_name=csv_path)
-        edge_data = csv_ds.EdgeData.load_from_csv(
-            meta_edge, csv_ds.DefaultDataParser())
+        meta_edge = MetaEdge(file_name=csv_path)
+        edge_data = EdgeData.load_from_csv(
+            meta_edge, DefaultDataParser())
         assert np.array_equal(df['src_id'], edge_data.src)
         assert np.array_equal(df['dst_id'], edge_data.dst)
         assert len(edge_data.data) == 0
@@ -556,9 +562,9 @@ def _test_load_edge_data_from_csv():
                            'label': np.random.randint(3, size=num_edges)})
         csv_path = os.path.join(test_dir, 'edges.csv')
         df.to_csv(csv_path, index=False)
-        meta_edge = csv_ds.MetaEdge(file_name=csv_path)
-        edge_data = csv_ds.EdgeData.load_from_csv(
-            meta_edge, csv_ds.DefaultDataParser())
+        meta_edge = MetaEdge(file_name=csv_path)
+        edge_data = EdgeData.load_from_csv(
+            meta_edge, DefaultDataParser())
         assert np.array_equal(df['src_id'], edge_data.src)
         assert np.array_equal(df['dst_id'], edge_data.dst)
         assert len(edge_data.data) == 1
@@ -574,9 +580,9 @@ def _test_load_edge_data_from_csv():
                            'label': np.random.randint(3, size=num_edges)})
         csv_path = os.path.join(test_dir, 'edges.csv')
         df.to_csv(csv_path, index=False)
-        meta_edge = csv_ds.MetaEdge(file_name=csv_path)
-        edge_data = csv_ds.EdgeData.load_from_csv(
-            meta_edge, csv_ds.DefaultDataParser())
+        meta_edge = MetaEdge(file_name=csv_path)
+        edge_data = EdgeData.load_from_csv(
+            meta_edge, DefaultDataParser())
         assert np.array_equal(df['src_id'], edge_data.src)
         assert np.array_equal(df['dst_id'], edge_data.dst)
         assert len(edge_data.data) == 2
@@ -590,11 +596,11 @@ def _test_load_edge_data_from_csv():
         })
         csv_path = os.path.join(test_dir, 'edges.csv')
         df.to_csv(csv_path, index=False)
-        meta_edge = csv_ds.MetaEdge(file_name=csv_path)
+        meta_edge = MetaEdge(file_name=csv_path)
         expect_except = False
         try:
-            csv_ds.EdgeData.load_from_csv(
-                meta_edge, csv_ds.DefaultDataParser())
+            EdgeData.load_from_csv(
+                meta_edge, DefaultDataParser())
         except DGLError:
             expect_except = True
         assert expect_except
@@ -602,26 +608,27 @@ def _test_load_edge_data_from_csv():
         })
         csv_path = os.path.join(test_dir, 'edges.csv')
         df.to_csv(csv_path, index=False)
-        meta_edge = csv_ds.MetaEdge(file_name=csv_path)
+        meta_edge = MetaEdge(file_name=csv_path)
         expect_except = False
         try:
-            csv_ds.EdgeData.load_from_csv(
-                meta_edge, csv_ds.DefaultDataParser())
+            EdgeData.load_from_csv(
+                meta_edge, DefaultDataParser())
         except DGLError:
             expect_except = True
         assert expect_except

 def _test_load_graph_data_from_csv():
+    from dgl.data.csv_dataset_base import MetaGraph, GraphData, DefaultDataParser
     with tempfile.TemporaryDirectory() as test_dir:
         num_graphs = 100
         # minimum
         df = pd.DataFrame({'graph_id': np.arange(num_graphs)})
         csv_path = os.path.join(test_dir, 'graph.csv')
         df.to_csv(csv_path, index=False)
-        meta_graph = csv_ds.MetaGraph(file_name=csv_path)
-        graph_data = csv_ds.GraphData.load_from_csv(
-            meta_graph, csv_ds.DefaultDataParser())
+        meta_graph = MetaGraph(file_name=csv_path)
+        graph_data = GraphData.load_from_csv(
+            meta_graph, DefaultDataParser())
         assert np.array_equal(df['graph_id'], graph_data.graph_id)
         assert len(graph_data.data) == 0
@@ -630,9 +637,9 @@ def _test_load_graph_data_from_csv():
                            'label': np.random.randint(3, size=num_graphs)})
         csv_path = os.path.join(test_dir, 'graph.csv')
         df.to_csv(csv_path, index=False)
-        meta_graph = csv_ds.MetaGraph(file_name=csv_path)
-        graph_data = csv_ds.GraphData.load_from_csv(
-            meta_graph, csv_ds.DefaultDataParser())
+        meta_graph = MetaGraph(file_name=csv_path)
+        graph_data = GraphData.load_from_csv(
+            meta_graph, DefaultDataParser())
         assert np.array_equal(df['graph_id'], graph_data.graph_id)
         assert len(graph_data.data) == 1
         assert np.array_equal(df['label'], graph_data.data['label'])
@@ -643,9 +650,9 @@ def _test_load_graph_data_from_csv():
                            'label': np.random.randint(3, size=num_graphs)})
         csv_path = os.path.join(test_dir, 'graph.csv')
         df.to_csv(csv_path, index=False)
-        meta_graph = csv_ds.MetaGraph(file_name=csv_path)
-        graph_data = csv_ds.GraphData.load_from_csv(
-            meta_graph, csv_ds.DefaultDataParser())
+        meta_graph = MetaGraph(file_name=csv_path)
+        graph_data = GraphData.load_from_csv(
+            meta_graph, DefaultDataParser())
         assert np.array_equal(df['graph_id'], graph_data.graph_id)
         assert len(graph_data.data) == 2
         assert np.array_equal(df['feat'], graph_data.data['feat'])
@@ -655,11 +662,11 @@ def _test_load_graph_data_from_csv():
         df = pd.DataFrame({'label': np.random.randint(3, size=num_graphs)})
         csv_path = os.path.join(test_dir, 'graph.csv')
         df.to_csv(csv_path, index=False)
-        meta_graph = csv_ds.MetaGraph(file_name=csv_path)
+        meta_graph = MetaGraph(file_name=csv_path)
         expect_except = False
         try:
-            csv_ds.GraphData.load_from_csv(
-                meta_graph, csv_ds.DefaultDataParser())
+            GraphData.load_from_csv(
+                meta_graph, DefaultDataParser())
         except DGLError:
             expect_except = True
         assert expect_except
@@ -716,7 +723,7 @@ def _test_DGLCSVDataset_single():
         # remove original node data file to verify reload from cached files
         os.remove(nodes_csv_path_0)
         assert not os.path.exists(nodes_csv_path_0)
-        csv_dataset = csv_ds.DGLCSVDataset(
+        csv_dataset = data.DGLCSVDataset(
             test_dir, force_reload=force_reload)
         assert len(csv_dataset) == 1
         g = csv_dataset[0]
@@ -799,7 +806,7 @@ def _test_DGLCSVDataset_multiple():
         # remove original node data file to verify reload from cached files
         os.remove(nodes_csv_path_0)
         assert not os.path.exists(nodes_csv_path_0)
-        csv_dataset = csv_ds.DGLCSVDataset(
+        csv_dataset = data.DGLCSVDataset(
             test_dir, force_reload=force_reload)
         assert len(csv_dataset) == num_graphs
         assert csv_dataset.has_cache()
@@ -885,7 +892,7 @@ def _test_DGLCSVDataset_customized_data_parser():
                data[header] = dt
            return data
     # load CSVDataset with customized node/edge/graph_data_parser
-    csv_dataset = csv_ds.DGLCSVDataset(
+    csv_dataset = data.DGLCSVDataset(
         test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser())
     assert len(csv_dataset) == num_graphs
     assert len(csv_dataset.data) == 1
@@ -906,10 +913,11 @@ def _test_DGLCSVDataset_customized_data_parser():

 def _test_NodeEdgeGraphData():
+    from dgl.data.csv_dataset_base import NodeData, EdgeData, GraphData
     # NodeData basics
     num_nodes = 100
     node_ids = np.arange(num_nodes, dtype=np.float64)
-    ndata = csv_ds.NodeData(node_ids, {})
+    ndata = NodeData(node_ids, {})
     assert ndata.id.dtype == np.int64
     assert np.array_equal(ndata.id, node_ids.astype(np.int64))
     assert len(ndata.data) == 0
@@ -918,7 +926,7 @@ def _test_NodeEdgeGraphData():
     # NodeData more
     data = {'feat': np.random.rand(num_nodes, 3)}
     graph_id = np.arange(num_nodes)
-    ndata = csv_ds.NodeData(node_ids, data, type='user', graph_id=graph_id)
+    ndata = NodeData(node_ids, data, type='user', graph_id=graph_id)
     assert ndata.type == 'user'
     assert np.array_equal(ndata.graph_id, graph_id)
     assert len(ndata.data) == len(data)
@@ -928,7 +936,7 @@ def _test_NodeEdgeGraphData():
     # NodeData except
     expect_except = False
     try:
-        csv_ds.NodeData(np.arange(num_nodes), {'feat': np.random.rand(
-            num_nodes+1, 3)}, graph_id=np.arange(num_nodes-1))
+        NodeData(np.arange(num_nodes), {'feat': np.random.rand(
+            num_nodes+1, 3)}, graph_id=np.arange(num_nodes-1))
     except:
         expect_except = True
@@ -939,7 +947,7 @@ def _test_NodeEdgeGraphData():
     num_edges = 1000
     src_ids = np.random.randint(num_nodes, size=num_edges)
     dst_ids = np.random.randint(num_nodes, size=num_edges)
-    edata = csv_ds.EdgeData(src_ids, dst_ids, {})
+    edata = EdgeData(src_ids, dst_ids, {})
     assert np.array_equal(edata.src, src_ids)
     assert np.array_equal(edata.dst, dst_ids)
     assert edata.type == ('_V', '_E', '_V')
@@ -951,7 +959,7 @@ def _test_NodeEdgeGraphData():
     data = {'feat': np.random.rand(num_edges, 3)}
     etype = ('user', 'like', 'item')
     graph_ids = np.arange(num_edges)
-    edata = csv_ds.EdgeData(src_ids, dst_ids, data,
-                            type=etype, graph_id=graph_ids)
+    edata = EdgeData(src_ids, dst_ids, data,
+                     type=etype, graph_id=graph_ids)
     assert edata.src.dtype == np.int64
     assert edata.dst.dtype == np.int64
@@ -966,7 +974,7 @@ def _test_NodeEdgeGraphData():
     # EdgeData except
     expect_except = False
     try:
-        csv_ds.EdgeData(np.arange(num_edges), np.arange(
-            num_edges+1), {'feat': np.random.rand(num_edges-1, 3)}, graph_id=np.arange(num_edges+2))
+        EdgeData(np.arange(num_edges), np.arange(
+            num_edges+1), {'feat': np.random.rand(num_edges-1, 3)}, graph_id=np.arange(num_edges+2))
     except:
         expect_except = True
@@ -975,13 +983,13 @@ def _test_NodeEdgeGraphData():
     # GraphData basics
     num_graphs = 10
     graph_ids = np.arange(num_graphs)
-    gdata = csv_ds.GraphData(graph_ids, {})
+    gdata = GraphData(graph_ids, {})
     assert np.array_equal(gdata.graph_id, graph_ids)
     assert len(gdata.data) == 0
     # GraphData more
     graph_ids = np.arange(num_graphs).astype(np.float64)
     data = {'feat': np.random.rand(num_graphs, 3)}
-    gdata = csv_ds.GraphData(graph_ids, data)
+    gdata = GraphData(graph_ids, data)
     assert gdata.graph_id.dtype == np.int64
     assert np.array_equal(gdata.graph_id, graph_ids)
     assert len(gdata.data) == len(data)
...