Unverified commit 9ec9df57 authored by Rhett Ying, committed by GitHub

[Fix] check and load dependencies when needed (#3655)

* [Fix] check and load dependencies when needed

* refine rdflib import
parent 77f4287a
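The theme of this change: optional dependencies (rdflib for the RDF datasets; pydantic and friends for the CSV dataset) are no longer imported at module load time, so `import dgl.data` succeeds even when they are absent, and each import moves into the code path that actually needs it. A minimal sketch of the pattern (the helper name `_load_rdflib` is illustrative, not code from this commit):

import importlib

def _load_rdflib():
    # Import rdflib only when an RDF dataset is actually instantiated.
    try:
        return importlib.import_module('rdflib')
    except ImportError:
        raise ImportError(
            'This dataset requires rdflib. Please install it first.')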
@@ -29,7 +29,7 @@ from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import DGLCSVDataset
def register_data_args(parser):
parser.add_argument(
......
import os
import yaml
from yaml.loader import SafeLoader
import pandas as pd
import numpy as np
from typing import List, Optional
import pydantic as dt
from .dgl_dataset import DGLDataset
from ..convert import heterograph as dgl_heterograph
from .. import backend as F
from .utils import save_graphs, load_graphs
from ..base import dgl_warning, DGLError
import abc
import ast
from typing import Callable
class MetaNode(dt.BaseModel):
""" Class of node_data in YAML. Internal use only. """
file_name: str
ntype: Optional[str] = '_V'
graph_id_field: Optional[str] = 'graph_id'
node_id_field: Optional[str] = 'node_id'
class MetaEdge(dt.BaseModel):
""" Class of edge_data in YAML. Internal use only. """
file_name: str
etype: Optional[List[str]] = ['_V', '_E', '_V']
graph_id_field: Optional[str] = 'graph_id'
src_id_field: Optional[str] = 'src_id'
dst_id_field: Optional[str] = 'dst_id'
class MetaGraph(dt.BaseModel):
""" Class of graph_data in YAML. Internal use only. """
file_name: str
graph_id_field: Optional[str] = 'graph_id'
class MetaYaml(dt.BaseModel):
""" Class of YAML. Internal use only. """
version: Optional[str] = '1.0.0'
dataset_name: str
separator: Optional[str] = ','
node_data: List[MetaNode]
edge_data: List[MetaEdge]
graph_data: Optional[MetaGraph] = None
def load_yaml_with_sanity_check(yaml_file):
""" Load yaml and do sanity check. Internal use only. """
with open(yaml_file) as f:
yaml_data = yaml.load(f, Loader=SafeLoader)
try:
meta_yaml = MetaYaml(**yaml_data)
except dt.ValidationError as e:
print(
"Details of pydantic.ValidationError:\n{}".format(e.json()))
raise DGLError(
"Validation Error for YAML fields. Details are shown above.")
if meta_yaml.version != '1.0.0':
raise DGLError("Invalid CSVDataset version {}. Supported versions: '1.0.0'".format(
meta_yaml.version))
ntypes = [meta.ntype for meta in meta_yaml.node_data]
if len(ntypes) > len(set(ntypes)):
raise DGLError(
"Each node CSV file must have a unique node type name, but found duplicate node type: {}.".format(ntypes))
etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data]
if len(etypes) > len(set(etypes)):
raise DGLError(
"Each edge CSV file must have a unique edge type name, but found duplicate edge type: {}.".format(etypes))
return meta_yaml
def _validate_data_length(data_dict):
len_dict = {k: len(v) for k, v in data_dict.items()}
lst = list(len_dict.values())
res = lst.count(lst[0]) == len(lst)
if not res:
raise DGLError(
"All data are required to have the same length, but lengths differ: {}".format(str(len_dict)))
class BaseData:
""" Class of base data which is inherited by Node/Edge/GraphData. Internal use only. """
@staticmethod
def read_csv(file_name, base_dir, separator):
csv_path = file_name
if base_dir is not None:
csv_path = os.path.join(base_dir, csv_path)
return pd.read_csv(csv_path, sep=separator)
@staticmethod
def pop_from_dataframe(df: pd.DataFrame, item: str):
ret = None
try:
ret = df.pop(item).to_numpy().squeeze()
except KeyError:
pass
return ret
class NodeData(BaseData):
""" Class of node data which is used for DGLGraph construction. Internal use only. """
def __init__(self, node_id, data, type=None, graph_id=None):
self.id = np.array(node_id, dtype=np.int64)
self.data = data
self.type = type if type is not None else '_V'
self.graph_id = np.array(graph_id, dtype=np.int64) if graph_id is not None else np.full(
len(node_id), 0)
_validate_data_length({**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if node_ids is None:
raise DGLError("Missing node id field [{}] in file [{}].".format(
meta.node_id_field, meta.file_name))
ntype = meta.ntype
ndata = data_parser(df)
return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids)
@staticmethod
def to_dict(node_data: List['NodeData']) -> dict:
# node_ids could be arbitrary numeric values: unsorted, duplicated, not labeled from 0 to num_nodes-1
node_dict = {}
for n_data in node_data:
graph_ids = np.unique(n_data.graph_id)
for graph_id in graph_ids:
idx = n_data.graph_id == graph_id
ids = n_data.id[idx]
u_ids, u_indices = np.unique(ids, return_index=True)
if len(ids) > len(u_ids):
dgl_warning(
"There exist duplicated ids and only the first ones are kept.")
if graph_id not in node_dict:
node_dict[graph_id] = {}
node_dict[graph_id][n_data.type] = {'mapping': {index: i for i,
index in enumerate(ids[u_indices])},
'data': {k: F.tensor(v[idx][u_indices])
for k, v in n_data.data.items()}}
return node_dict
class EdgeData(BaseData):
""" Class of edge data which is used for DGLGraph construction. Internal use only. """
def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
self.src = np.array(src_id, dtype=np.int64)
self.dst = np.array(dst_id, dtype=np.int64)
self.data = data
self.type = type if type is not None else ('_V', '_E', '_V')
self.graph_id = np.array(graph_id, dtype=np.int64) if graph_id is not None else np.full(
len(src_id), 0)
_validate_data_length({**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field)
if src_ids is None:
raise DGLError("Missing src id field [{}] in file [{}].".format(
meta.src_id_field, meta.file_name))
dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field)
if dst_ids is None:
raise DGLError("Missing dst id field [{}] in file [{}].".format(
meta.dst_id_field, meta.file_name))
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
etype = tuple(meta.etype)
edata = data_parser(df)
return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids)
@staticmethod
def to_dict(edge_data: List['EdgeData'], node_dict: dict) -> dict:
edge_dict = {}
for e_data in edge_data:
(src_type, e_type, dst_type) = e_data.type
graph_ids = np.unique(e_data.graph_id)
for graph_id in graph_ids:
if graph_id in edge_dict and e_data.type in edge_dict[graph_id]:
raise DGLError(f"Duplicate edge type[{e_data.type}] for same graph[{graph_id}], please place the same edge_type for same graph into single EdgeData.")
idx = e_data.graph_id == graph_id
src_mapping = node_dict[graph_id][src_type]['mapping']
dst_mapping = node_dict[graph_id][dst_type]['mapping']
src_ids = [src_mapping[index] for index in e_data.src[idx]]
dst_ids = [dst_mapping[index] for index in e_data.dst[idx]]
if graph_id not in edge_dict:
edge_dict[graph_id] = {}
edge_dict[graph_id][e_data.type] = {'edges': (F.tensor(src_ids), F.tensor(dst_ids)),
'data': {k: F.tensor(v[idx])
for k, v in e_data.data.items()}}
return edge_dict
class GraphData(BaseData):
""" Class of graph data which is used for DGLGraph construction. Internal use only. """
def __init__(self, graph_id, data):
self.graph_id = np.array(graph_id, dtype=np.int64)
self.data = data
_validate_data_length({**{'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaGraph, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if graph_ids is None:
raise DGLError("Missing graph id field [{}] in file [{}].".format(
meta.graph_id_field, meta.file_name))
gdata = data_parser(df)
return GraphData(graph_ids, gdata)
@staticmethod
def to_dict(graph_data: 'GraphData', graphs_dict: dict) -> dict:
missing_ids = np.setdiff1d(
np.array(list(graphs_dict.keys())), graph_data.graph_id)
if len(missing_ids) > 0:
raise DGLError(
"Found following graph ids in node/edge CSVs but not in graph CSV: {}.".format(missing_ids))
graph_ids = graph_data.graph_id
graphs = []
for graph_id in graph_ids:
if graph_id not in graphs_dict:
graphs_dict[graph_id] = dgl_heterograph(
{('_V', '_E', '_V'): ([], [])})
for graph_id in graph_ids:
graphs.append(graphs_dict[graph_id])
data = {k: F.tensor(v) for k, v in graph_data.data.items()}
return graphs, data
class DGLGraphConstructor:
""" Class for constructing DGLGraph from Node/Edge/Graph data. Internal use only. """
@staticmethod
def construct_graphs(node_data, edge_data, graph_data=None):
if not isinstance(node_data, list):
node_data = [node_data]
if not isinstance(edge_data, list):
edge_data = [edge_data]
node_dict = NodeData.to_dict(node_data)
edge_dict = EdgeData.to_dict(edge_data, node_dict)
graph_dict = DGLGraphConstructor._construct_graphs(
node_dict, edge_dict)
if graph_data is None:
graph_data = GraphData(np.full(1, 0), {})
graphs, data = GraphData.to_dict(
graph_data, graph_dict)
return graphs, data
@staticmethod
def _construct_graphs(node_dict, edge_dict):
graph_dict = {}
for graph_id in node_dict:
if graph_id not in edge_dict:
edge_dict[graph_id] = {('_V', '_E', '_V'): {'edges': ([], []), 'data': {}}}
graph = dgl_heterograph({etype: edata['edges']
for etype, edata in edge_dict[graph_id].items()},
num_nodes_dict={ntype: len(ndata['mapping'])
for ntype, ndata in node_dict[graph_id].items()})
def assign_data(type, src_data, dst_data):
for key, value in src_data.items():
dst_data[type].data[key] = value
for type, data in node_dict[graph_id].items():
assign_data(type, data['data'], graph.nodes)
for type, data in edge_dict[graph_id].items():
assign_data(type, data['data'], graph.edges)
graph_dict[graph_id] = graph
return graph_dict
class DefaultDataParser:
""" Default data parser for DGLCSVDataset. It
1. ignores any columns which does not have a header.
2. tries to convert to list of numeric values(generated by
np.array().tolist()) if cell data is a str separated by ','.
3. read data and infer data type directly, otherwise.
"""
def __call__(self, df: pd.DataFrame):
data = {}
for header in df:
if 'Unnamed' in header:
dgl_warning("Unamed column is found. Ignored...")
continue
dt = df[header].to_numpy().squeeze()
if len(dt) > 0 and isinstance(dt[0], str):
# probably a list of numeric values encoded as a comma-separated string
dt = np.array([ast.literal_eval(row) for row in dt])
data[header] = dt
return data
from ..base import DGLError
class DGLCSVDataset(DGLDataset):
@@ -337,6 +47,7 @@ class DGLCSVDataset(DGLDataset):
META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
self.node_data_parser = {} if node_data_parser is None else node_data_parser
@@ -352,9 +63,11 @@ class DGLCSVDataset(DGLDataset):
super().__init__(ds_name, raw_dir=os.path.dirname(
meta_yaml_path), force_reload=force_reload, verbose=verbose)
def process(self):
"""Parse node/edge data from CSV files and construct DGL.Graphs
"""
from .csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor
meta_yaml = self.meta_yaml
base_dir = self.raw_dir
node_data = []
......
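Note that deferring the heavy imports into __init__ and process leaves the public API untouched; an illustrative call (the directory name is a placeholder and must contain meta.yaml plus the CSV files it references):

import dgl.data

dataset = dgl.data.DGLCSVDataset('./toy_dataset/')
g = dataset[0]  # the first constructed graph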
import os
import numpy as np
from typing import List, Optional, Callable
from .. import backend as F
from ..convert import heterograph as dgl_heterograph
from ..base import dgl_warning, DGLError
import ast
import pydantic as dt
import pandas as pd
import yaml
class MetaNode(dt.BaseModel):
""" Class of node_data in YAML. Internal use only. """
file_name: str
ntype: Optional[str] = '_V'
graph_id_field: Optional[str] = 'graph_id'
node_id_field: Optional[str] = 'node_id'
class MetaEdge(dt.BaseModel):
""" Class of edge_data in YAML. Internal use only. """
file_name: str
etype: Optional[List[str]] = ['_V', '_E', '_V']
graph_id_field: Optional[str] = 'graph_id'
src_id_field: Optional[str] = 'src_id'
dst_id_field: Optional[str] = 'dst_id'
class MetaGraph(dt.BaseModel):
""" Class of graph_data in YAML. Internal use only. """
file_name: str
graph_id_field: Optional[str] = 'graph_id'
class MetaYaml(dt.BaseModel):
""" Class of YAML. Internal use only. """
version: Optional[str] = '1.0.0'
dataset_name: str
separator: Optional[str] = ','
node_data: List[MetaNode]
edge_data: List[MetaEdge]
graph_data: Optional[MetaGraph] = None
def load_yaml_with_sanity_check(yaml_file):
""" Load yaml and do sanity check. Internal use only. """
with open(yaml_file) as f:
yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
try:
meta_yaml = MetaYaml(**yaml_data)
except dt.ValidationError as e:
print(
"Details of pydantic.ValidationError:\n{}".format(e.json()))
raise DGLError(
"Validation Error for YAML fields. Details are shown above.")
if meta_yaml.version != '1.0.0':
raise DGLError("Invalid CSVDataset version {}. Supported versions: '1.0.0'".format(
meta_yaml.version))
ntypes = [meta.ntype for meta in meta_yaml.node_data]
if len(ntypes) > len(set(ntypes)):
raise DGLError(
"Each node CSV file must have a unique node type name, but found duplicate node type: {}.".format(ntypes))
etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data]
if len(etypes) > len(set(etypes)):
raise DGLError(
"Each edge CSV file must have a unique edge type name, but found duplicate edge type: {}.".format(etypes))
return meta_yaml
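For reference, a minimal meta.yaml accepted by load_yaml_with_sanity_check can be exercised as follows; this is an editor's sketch with placeholder file names, not part of the diff:

import os
import tempfile

import yaml

from dgl.data.csv_dataset_base import load_yaml_with_sanity_check

meta = {
    'dataset_name': 'toy',
    'node_data': [{'file_name': 'nodes.csv'}],
    'edge_data': [{'file_name': 'edges.csv'}],
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'meta.yaml')
    with open(path, 'w') as f:
        yaml.dump(meta, f)
    parsed = load_yaml_with_sanity_check(path)
    # Omitted optional fields fall back to their declared defaults.
    assert parsed.version == '1.0.0' and parsed.separator == ','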
def _validate_data_length(data_dict):
len_dict = {k: len(v) for k, v in data_dict.items()}
lst = list(len_dict.values())
res = lst.count(lst[0]) == len(lst)
if not res:
raise DGLError(
"All data are required to have the same length, but lengths differ: {}".format(str(len_dict)))
class BaseData:
""" Class of base data which is inherited by Node/Edge/GraphData. Internal use only. """
@staticmethod
def read_csv(file_name, base_dir, separator):
csv_path = file_name
if base_dir is not None:
csv_path = os.path.join(base_dir, csv_path)
return pd.read_csv(csv_path, sep=separator)
@staticmethod
def pop_from_dataframe(df: pd.DataFrame, item: str):
ret = None
try:
ret = df.pop(item).to_numpy().squeeze()
except KeyError:
pass
return ret
class NodeData(BaseData):
""" Class of node data which is used for DGLGraph construction. Internal use only. """
def __init__(self, node_id, data, type=None, graph_id=None):
self.id = np.array(node_id, dtype=np.int64)
self.data = data
self.type = type if type is not None else '_V'
self.graph_id = np.array(graph_id, dtype=np.int64) if graph_id is not None else np.full(
len(node_id), 0)
_validate_data_length({**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if node_ids is None:
raise DGLError("Missing node id field [{}] in file [{}].".format(
meta.node_id_field, meta.file_name))
ntype = meta.ntype
ndata = data_parser(df)
return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids)
@staticmethod
def to_dict(node_data: List['NodeData']) -> dict:
# node_ids could be arbitrary numeric values: unsorted, duplicated, not labeled from 0 to num_nodes-1
node_dict = {}
for n_data in node_data:
graph_ids = np.unique(n_data.graph_id)
for graph_id in graph_ids:
idx = n_data.graph_id == graph_id
ids = n_data.id[idx]
u_ids, u_indices = np.unique(ids, return_index=True)
if len(ids) > len(u_ids):
dgl_warning(
"There exist duplicated ids and only the first ones are kept.")
if graph_id not in node_dict:
node_dict[graph_id] = {}
node_dict[graph_id][n_data.type] = {'mapping': {index: i for i,
index in enumerate(ids[u_indices])},
'data': {k: F.tensor(v[idx][u_indices])
for k, v in n_data.data.items()}}
return node_dict
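To make the mapping comprehension above concrete: np.unique keeps the first occurrence of each id, and the raw ids are remapped to consecutive node indices. A tiny worked example (values illustrative):

import numpy as np

ids = np.array([10, 5, 10, 7])
u_ids, u_indices = np.unique(ids, return_index=True)  # u_ids=[5, 7, 10], u_indices=[1, 3, 0]
mapping = {index: i for i, index in enumerate(ids[u_indices])}
assert mapping == {5: 0, 7: 1, 10: 2}  # original id -> contiguous index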
class EdgeData(BaseData):
""" Class of edge data which is used for DGLGraph construction. Internal use only. """
def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
self.src = np.array(src_id, dtype=np.int64)
self.dst = np.array(dst_id, dtype=np.int64)
self.data = data
self.type = type if type is not None else ('_V', '_E', '_V')
self.graph_id = np.array(graph_id, dtype=np.int64) if graph_id is not None else np.full(
len(src_id), 0)
_validate_data_length({**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field)
if src_ids is None:
raise DGLError("Missing src id field [{}] in file [{}].".format(
meta.src_id_field, meta.file_name))
dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field)
if dst_ids is None:
raise DGLError("Missing dst id field [{}] in file [{}].".format(
meta.dst_id_field, meta.file_name))
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
etype = tuple(meta.etype)
edata = data_parser(df)
return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids)
@staticmethod
def to_dict(edge_data: List['EdgeData'], node_dict: dict) -> dict:
edge_dict = {}
for e_data in edge_data:
(src_type, e_type, dst_type) = e_data.type
graph_ids = np.unique(e_data.graph_id)
for graph_id in graph_ids:
if graph_id in edge_dict and e_data.type in edge_dict[graph_id]:
raise DGLError(f"Duplicate edge type[{e_data.type}] for same graph[{graph_id}], please place the same edge_type for same graph into single EdgeData.")
idx = e_data.graph_id == graph_id
src_mapping = node_dict[graph_id][src_type]['mapping']
dst_mapping = node_dict[graph_id][dst_type]['mapping']
src_ids = [src_mapping[index] for index in e_data.src[idx]]
dst_ids = [dst_mapping[index] for index in e_data.dst[idx]]
if graph_id not in edge_dict:
edge_dict[graph_id] = {}
edge_dict[graph_id][e_data.type] = {'edges': (F.tensor(src_ids), F.tensor(dst_ids)),
'data': {k: F.tensor(v[idx])
for k, v in e_data.data.items()}}
return edge_dict
class GraphData(BaseData):
""" Class of graph data which is used for DGLGraph construction. Internal use only. """
def __init__(self, graph_id, data):
self.graph_id = np.array(graph_id, dtype=np.int64)
self.data = data
_validate_data_length({**{'graph_id': self.graph_id}, **self.data})
@staticmethod
def load_from_csv(meta: MetaGraph, data_parser: Callable, base_dir=None, separator=','):
df = BaseData.read_csv(meta.file_name, base_dir, separator)
graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
if graph_ids is None:
raise DGLError("Missing graph id field [{}] in file [{}].".format(
meta.graph_id_field, meta.file_name))
gdata = data_parser(df)
return GraphData(graph_ids, gdata)
@staticmethod
def to_dict(graph_data: 'GraphData', graphs_dict: dict) -> dict:
missing_ids = np.setdiff1d(
np.array(list(graphs_dict.keys())), graph_data.graph_id)
if len(missing_ids) > 0:
raise DGLError(
"Found following graph ids in node/edge CSVs but not in graph CSV: {}.".format(missing_ids))
graph_ids = graph_data.graph_id
graphs = []
for graph_id in graph_ids:
if graph_id not in graphs_dict:
graphs_dict[graph_id] = dgl_heterograph(
{('_V', '_E', '_V'): ([], [])})
for graph_id in graph_ids:
graphs.append(graphs_dict[graph_id])
data = {k: F.tensor(v) for k, v in graph_data.data.items()}
return graphs, data
class DGLGraphConstructor:
""" Class for constructing DGLGraph from Node/Edge/Graph data. Internal use only. """
@staticmethod
def construct_graphs(node_data, edge_data, graph_data=None):
if not isinstance(node_data, list):
node_data = [node_data]
if not isinstance(edge_data, list):
edge_data = [edge_data]
node_dict = NodeData.to_dict(node_data)
edge_dict = EdgeData.to_dict(edge_data, node_dict)
graph_dict = DGLGraphConstructor._construct_graphs(
node_dict, edge_dict)
if graph_data is None:
graph_data = GraphData(np.full(1, 0), {})
graphs, data = GraphData.to_dict(
graph_data, graph_dict)
return graphs, data
@staticmethod
def _construct_graphs(node_dict, edge_dict):
graph_dict = {}
for graph_id in node_dict:
if graph_id not in edge_dict:
edge_dict[graph_id] = {('_V', '_E', '_V'): {'edges': ([], []), 'data': {}}}
graph = dgl_heterograph({etype: edata['edges']
for etype, edata in edge_dict[graph_id].items()},
num_nodes_dict={ntype: len(ndata['mapping'])
for ntype, ndata in node_dict[graph_id].items()})
def assign_data(type, src_data, dst_data):
for key, value in src_data.items():
dst_data[type].data[key] = value
for type, data in node_dict[graph_id].items():
assign_data(type, data['data'], graph.nodes)
for type, data in edge_dict[graph_id].items():
assign_data(type, data['data'], graph.edges)
graph_dict[graph_id] = graph
return graph_dict
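A minimal sketch of driving the constructor directly (editor's example assuming the classes above, not part of the diff):

import numpy as np
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor

# Three nodes with arbitrary raw ids; two edges expressed in those ids.
node_data = NodeData([10, 20, 30], {'feat': np.random.rand(3, 2)})
edge_data = EdgeData([10, 20], [20, 30], {})
graphs, data = DGLGraphConstructor.construct_graphs(node_data, edge_data)
assert len(graphs) == 1 and graphs[0].num_nodes() == 3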
class DefaultDataParser:
""" Default data parser for DGLCSVDataset. It
1. ignores any columns which does not have a header.
2. tries to convert to list of numeric values(generated by
np.array().tolist()) if cell data is a str separated by ','.
3. read data and infer data type directly, otherwise.
"""
def __call__(self, df: pd.DataFrame):
data = {}
for header in df:
if 'Unnamed' in header:
dgl_warning("Unamed column is found. Ignored...")
continue
dt = df[header].to_numpy().squeeze()
if len(dt) > 0 and isinstance(dt[0], str):
# probably a list of numeric values encoded as a comma-separated string
dt = np.array([ast.literal_eval(row) for row in dt])
data[header] = dt
return data
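For illustration, the default parser applied to a frame with one numeric column and one string-encoded list column (editor's sketch, not part of the diff):

import numpy as np
import pandas as pd
from dgl.data.csv_dataset_base import DefaultDataParser

df = pd.DataFrame({
    'label': [0, 1],
    'feat': ['1.0, 2.0', '3.0, 4.0'],  # str cells go through ast.literal_eval
})
parsed = DefaultDataParser()(df)
assert np.array_equal(parsed['label'], [0, 1])
assert parsed['feat'].shape == (2, 2)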
@@ -8,10 +8,6 @@ from collections import OrderedDict
import itertools
import abc
import re
try:
import rdflib as rdf
except ImportError:
pass
import networkx as nx
import numpy as np
@@ -115,7 +111,6 @@ class RDFGraphDataset(DGLBuiltinDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
self._insert_reverse = insert_reverse
self._print_every = print_every
self._predict_category = predict_category
@@ -141,6 +136,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
-------
Loaded rdf data
"""
import rdflib as rdf
raw_rdf_graphs = []
for _, filename in enumerate(os.listdir(root_path)):
fmt = None
@@ -674,6 +670,7 @@ class AIFBDataset(RDFGraphDataset):
return super(AIFBDataset, self).__len__()
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal")
if isinstance(term, rdf.BNode):
@@ -855,6 +852,7 @@ class MUTAGDataset(RDFGraphDataset):
return super(MUTAGDataset, self).__len__()
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal")
elif isinstance(term, rdf.BNode):
@@ -1052,6 +1050,7 @@ class BGSDataset(RDFGraphDataset):
return super(BGSDataset, self).__len__()
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
@@ -1247,6 +1246,7 @@ class AMDataset(RDFGraphDataset):
return super(AMDataset, self).__len__()
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
......
@@ -8,7 +8,6 @@ import pandas as pd
import yaml
import pytest
import dgl.data as data
import dgl.data.csv_dataset as csv_ds
from dgl import DGLError
@@ -164,6 +163,7 @@ def test_extract_archive():
def _test_construct_graphs_homo():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
# node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
num_nodes = 100
num_edges = 1000
@@ -179,13 +179,13 @@ def _test_construct_graphs_homo():
_, u_indices = np.unique(node_ids, return_index=True)
ndata = {'feat': t_ndata['feat'][u_indices],
'label': t_ndata['label'][u_indices]}
node_data = csv_ds.NodeData(node_ids, t_ndata)
node_data = NodeData(node_ids, t_ndata)
src_ids = np.random.choice(node_ids, size=num_edges)
dst_ids = np.random.choice(node_ids, size=num_edges)
edata = {'feat': np.random.rand(
num_edges, num_dims), 'label': np.random.randint(2, size=num_edges)}
edge_data = csv_ds.EdgeData(src_ids, dst_ids, edata)
graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
edge_data = EdgeData(src_ids, dst_ids, edata)
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data)
assert len(graphs) == 1
assert len(data_dict) == 0
@@ -203,6 +203,7 @@ def _test_construct_graphs_homo():
def _test_construct_graphs_hetero():
from dgl.data.csv_dataset_base import NodeData, EdgeData, DGLGraphConstructor
# node_ids could be non-sorted, duplicated, not labeled from 0 to num_nodes-1
num_nodes = 100
num_edges = 1000
@@ -223,7 +224,7 @@ def _test_construct_graphs_hetero():
_, u_indices = np.unique(node_ids, return_index=True)
ndata = {'feat': t_ndata['feat'][u_indices],
'label': t_ndata['label'][u_indices]}
node_data.append(csv_ds.NodeData(node_ids, t_ndata, type=ntype))
node_data.append(NodeData(node_ids, t_ndata, type=ntype))
node_ids_dict[ntype] = node_ids
ndata_dict[ntype] = ndata
etypes = [('user', 'follow', 'user'), ('user', 'like', 'item')]
@@ -234,10 +235,10 @@ def _test_construct_graphs_hetero():
dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
edata = {'feat': np.random.rand(
num_edges, num_dims), 'label': np.random.randint(2, size=num_edges)}
edge_data.append(csv_ds.EdgeData(src_ids, dst_ids, edata,
edge_data.append(EdgeData(src_ids, dst_ids, edata,
type=(src_type, e_type, dst_type)))
edata_dict[(src_type, e_type, dst_type)] = edata
graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data)
assert len(graphs) == 1
assert len(data_dict) == 0
@@ -259,6 +260,7 @@ def _test_construct_graphs_hetero():
def _test_construct_graphs_multiple():
from dgl.data.csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor
num_nodes = 100
num_edges = 1000
num_graphs = 10
@@ -283,14 +285,14 @@ def _test_construct_graphs_multiple():
egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
ndata = {'feat': np.random.rand(num_nodes*num_graphs, num_dims),
'label': np.random.randint(2, size=num_nodes*num_graphs)}
node_data = csv_ds.NodeData(node_ids, ndata, graph_id=ngraph_ids)
node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
edata = {'feat': np.random.rand(
num_edges*num_graphs, num_dims), 'label': np.random.randint(2, size=num_edges*num_graphs)}
edge_data = csv_ds.EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
gdata = {'feat': np.random.rand(num_graphs, num_dims),
'label': np.random.randint(2, size=num_graphs)}
graph_data = csv_ds.GraphData(np.arange(num_graphs), gdata)
graphs, data_dict = csv_ds.DGLGraphConstructor.construct_graphs(
graph_data = GraphData(np.arange(num_graphs), gdata)
graphs, data_dict = DGLGraphConstructor.construct_graphs(
node_data, edge_data, graph_data)
assert len(graphs) == num_graphs
assert len(data_dict) == len(gdata)
@@ -313,10 +315,10 @@ def _test_construct_graphs_multiple():
assert_data(edata, g.edata, num_edges)
# Graph IDs found in node/edge CSV but not in graph CSV
graph_data = csv_ds.GraphData(np.arange(num_graphs-2), {})
graph_data = GraphData(np.arange(num_graphs-2), {})
expect_except = False
try:
_, _ = csv_ds.DGLGraphConstructor.construct_graphs(
_, _ = DGLGraphConstructor.construct_graphs(
node_data, edge_data, graph_data)
except:
expect_except = True
@@ -324,6 +326,7 @@ def _test_construct_graphs_multiple():
def _test_DefaultDataParser():
from dgl.data.csv_dataset_base import DefaultDataParser
# common csv
with tempfile.TemporaryDirectory() as test_dir:
csv_path = os.path.join(test_dir, "nodes.csv")
@@ -337,7 +340,7 @@ def _test_DefaultDataParser():
'feat': [line.tolist() for line in feat],
})
df.to_csv(csv_path, index=False)
dp = csv_ds.DefaultDataParser()
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
assert np.array_equal(node_id, dt['node_id'])
@@ -349,7 +352,7 @@ def _test_DefaultDataParser():
df = pd.DataFrame({'label': ['a', 'b', 'c'],
})
df.to_csv(csv_path, index=False)
dp = csv_ds.DefaultDataParser()
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
expect_except = False
try:
@@ -363,13 +366,14 @@ def _test_DefaultDataParser():
df = pd.DataFrame({'label': [1, 2, 3],
})
df.to_csv(csv_path)
dp = csv_ds.DefaultDataParser()
dp = DefaultDataParser()
df = pd.read_csv(csv_path)
dt = dp(df)
assert len(dt) == 1
def _test_load_yaml_with_sanity_check():
from dgl.data.csv_dataset_base import load_yaml_with_sanity_check
with tempfile.TemporaryDirectory() as test_dir:
yaml_path = os.path.join(test_dir, 'meta.yaml')
# workable but meaningless usually
@@ -377,7 +381,7 @@ def _test_load_yaml_with_sanity_check():
'node_data': [], 'edge_data': []}
with open(yaml_path, 'w') as f:
yaml.dump(yaml_data, f, sort_keys=False)
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
assert meta.version == '1.0.0'
assert meta.dataset_name == 'default'
assert meta.separator == ','
@@ -390,7 +394,7 @@ def _test_load_yaml_with_sanity_check():
}
with open(yaml_path, 'w') as f:
yaml.dump(yaml_data, f, sort_keys=False)
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
for ndata in meta.node_data:
assert ndata.file_name == 'nodes.csv'
assert ndata.ntype == '_V'
@@ -411,7 +415,7 @@ def _test_load_yaml_with_sanity_check():
}
with open(yaml_path, 'w') as f:
yaml.dump(yaml_data, f, sort_keys=False)
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
assert len(meta.node_data) == 1
ndata = meta.node_data[0]
assert ndata.ntype == 'user'
@@ -436,7 +440,7 @@ def _test_load_yaml_with_sanity_check():
yaml.dump(ydata, f, sort_keys=False)
expect_except = False
try:
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
except:
expect_except = True
assert expect_except
@@ -448,7 +452,7 @@ def _test_load_yaml_with_sanity_check():
yaml.dump(yaml_data, f, sort_keys=False)
expect_except = False
try:
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
except DGLError:
expect_except = True
assert expect_except
@@ -460,7 +464,7 @@ def _test_load_yaml_with_sanity_check():
yaml.dump(yaml_data, f, sort_keys=False)
expect_except = False
try:
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
except DGLError:
expect_except = True
assert expect_except
@@ -472,22 +476,23 @@ def _test_load_yaml_with_sanity_check():
yaml.dump(yaml_data, f, sort_keys=False)
expect_except = False
try:
meta = csv_ds.load_yaml_with_sanity_check(yaml_path)
meta = load_yaml_with_sanity_check(yaml_path)
except DGLError:
expect_except = True
assert expect_except
def _test_load_node_data_from_csv():
from dgl.data.csv_dataset_base import MetaNode, NodeData, DefaultDataParser
with tempfile.TemporaryDirectory() as test_dir:
num_nodes = 100
# minimum
df = pd.DataFrame({'node_id': np.arange(num_nodes)})
csv_path = os.path.join(test_dir, 'nodes.csv')
df.to_csv(csv_path, index=False)
meta_node = csv_ds.MetaNode(file_name=csv_path)
node_data = csv_ds.NodeData.load_from_csv(
meta_node, csv_ds.DefaultDataParser())
meta_node = MetaNode(file_name=csv_path)
node_data = NodeData.load_from_csv(
meta_node, DefaultDataParser())
assert np.array_equal(df['node_id'], node_data.id)
assert len(node_data.data) == 0
@@ -496,9 +501,9 @@ def _test_load_node_data_from_csv():
'label': np.random.randint(3, size=num_nodes)})
csv_path = os.path.join(test_dir, 'nodes.csv')
df.to_csv(csv_path, index=False)
meta_node = csv_ds.MetaNode(file_name=csv_path)
node_data = csv_ds.NodeData.load_from_csv(
meta_node, csv_ds.DefaultDataParser())
meta_node = MetaNode(file_name=csv_path)
node_data = NodeData.load_from_csv(
meta_node, DefaultDataParser())
assert np.array_equal(df['node_id'], node_data.id)
assert len(node_data.data) == 1
assert np.array_equal(df['label'], node_data.data['label'])
@@ -510,9 +515,9 @@ def _test_load_node_data_from_csv():
3, size=num_nodes), 'graph_id': np.full(num_nodes, 1)})
csv_path = os.path.join(test_dir, 'nodes.csv')
df.to_csv(csv_path, index=False)
meta_node = csv_ds.MetaNode(file_name=csv_path)
node_data = csv_ds.NodeData.load_from_csv(
meta_node, csv_ds.DefaultDataParser())
meta_node = MetaNode(file_name=csv_path)
node_data = NodeData.load_from_csv(
meta_node, DefaultDataParser())
assert np.array_equal(df['node_id'], node_data.id)
assert len(node_data.data) == 1
assert np.array_equal(df['label'], node_data.data['label'])
@@ -523,17 +528,18 @@ def _test_load_node_data_from_csv():
df = pd.DataFrame({'label': np.random.randint(3, size=num_nodes)})
csv_path = os.path.join(test_dir, 'nodes.csv')
df.to_csv(csv_path, index=False)
meta_node = csv_ds.MetaNode(file_name=csv_path)
meta_node = MetaNode(file_name=csv_path)
expect_except = False
try:
csv_ds.NodeData.load_from_csv(
meta_node, csv_ds.DefaultDataParser())
NodeData.load_from_csv(
meta_node, DefaultDataParser())
except:
expect_except = True
assert expect_except
def _test_load_edge_data_from_csv():
from dgl.data.csv_dataset_base import MetaEdge, EdgeData, DefaultDataParser
with tempfile.TemporaryDirectory() as test_dir:
num_nodes = 100
num_edges = 1000
@@ -543,9 +549,9 @@ def _test_load_edge_data_from_csv():
})
csv_path = os.path.join(test_dir, 'edges.csv')
df.to_csv(csv_path, index=False)
meta_edge = csv_ds.MetaEdge(file_name=csv_path)
edge_data = csv_ds.EdgeData.load_from_csv(
meta_edge, csv_ds.DefaultDataParser())
meta_edge = MetaEdge(file_name=csv_path)
edge_data = EdgeData.load_from_csv(
meta_edge, DefaultDataParser())
assert np.array_equal(df['src_id'], edge_data.src)
assert np.array_equal(df['dst_id'], edge_data.dst)
assert len(edge_data.data) == 0
@@ -556,9 +562,9 @@ def _test_load_edge_data_from_csv():
'label': np.random.randint(3, size=num_edges)})
csv_path = os.path.join(test_dir, 'edges.csv')
df.to_csv(csv_path, index=False)
meta_edge = csv_ds.MetaEdge(file_name=csv_path)
edge_data = csv_ds.EdgeData.load_from_csv(
meta_edge, csv_ds.DefaultDataParser())
meta_edge = MetaEdge(file_name=csv_path)
edge_data = EdgeData.load_from_csv(
meta_edge, DefaultDataParser())
assert np.array_equal(df['src_id'], edge_data.src)
assert np.array_equal(df['dst_id'], edge_data.dst)
assert len(edge_data.data) == 1
@@ -574,9 +580,9 @@ def _test_load_edge_data_from_csv():
'label': np.random.randint(3, size=num_edges)})
csv_path = os.path.join(test_dir, 'edges.csv')
df.to_csv(csv_path, index=False)
meta_edge = csv_ds.MetaEdge(file_name=csv_path)
edge_data = csv_ds.EdgeData.load_from_csv(
meta_edge, csv_ds.DefaultDataParser())
meta_edge = MetaEdge(file_name=csv_path)
edge_data = EdgeData.load_from_csv(
meta_edge, DefaultDataParser())
assert np.array_equal(df['src_id'], edge_data.src)
assert np.array_equal(df['dst_id'], edge_data.dst)
assert len(edge_data.data) == 2
@@ -590,11 +596,11 @@ def _test_load_edge_data_from_csv():
})
csv_path = os.path.join(test_dir, 'edges.csv')
df.to_csv(csv_path, index=False)
meta_edge = csv_ds.MetaEdge(file_name=csv_path)
meta_edge = MetaEdge(file_name=csv_path)
expect_except = False
try:
csv_ds.EdgeData.load_from_csv(
meta_edge, csv_ds.DefaultDataParser())
EdgeData.load_from_csv(
meta_edge, DefaultDataParser())
except DGLError:
expect_except = True
assert expect_except
@@ -602,26 +608,27 @@ def _test_load_edge_data_from_csv():
})
csv_path = os.path.join(test_dir, 'edges.csv')
df.to_csv(csv_path, index=False)
meta_edge = csv_ds.MetaEdge(file_name=csv_path)
meta_edge = MetaEdge(file_name=csv_path)
expect_except = False
try:
csv_ds.EdgeData.load_from_csv(
meta_edge, csv_ds.DefaultDataParser())
EdgeData.load_from_csv(
meta_edge, DefaultDataParser())
except DGLError:
expect_except = True
assert expect_except
def _test_load_graph_data_from_csv():
from dgl.data.csv_dataset_base import MetaGraph, GraphData, DefaultDataParser
with tempfile.TemporaryDirectory() as test_dir:
num_graphs = 100
# minimum
df = pd.DataFrame({'graph_id': np.arange(num_graphs)})
csv_path = os.path.join(test_dir, 'graph.csv')
df.to_csv(csv_path, index=False)
meta_graph = csv_ds.MetaGraph(file_name=csv_path)
graph_data = csv_ds.GraphData.load_from_csv(
meta_graph, csv_ds.DefaultDataParser())
meta_graph = MetaGraph(file_name=csv_path)
graph_data = GraphData.load_from_csv(
meta_graph, DefaultDataParser())
assert np.array_equal(df['graph_id'], graph_data.graph_id)
assert len(graph_data.data) == 0
@@ -630,9 +637,9 @@ def _test_load_graph_data_from_csv():
'label': np.random.randint(3, size=num_graphs)})
csv_path = os.path.join(test_dir, 'graph.csv')
df.to_csv(csv_path, index=False)
meta_graph = csv_ds.MetaGraph(file_name=csv_path)
graph_data = csv_ds.GraphData.load_from_csv(
meta_graph, csv_ds.DefaultDataParser())
meta_graph = MetaGraph(file_name=csv_path)
graph_data = GraphData.load_from_csv(
meta_graph, DefaultDataParser())
assert np.array_equal(df['graph_id'], graph_data.graph_id)
assert len(graph_data.data) == 1
assert np.array_equal(df['label'], graph_data.data['label'])
@@ -643,9 +650,9 @@ def _test_load_graph_data_from_csv():
'label': np.random.randint(3, size=num_graphs)})
csv_path = os.path.join(test_dir, 'graph.csv')
df.to_csv(csv_path, index=False)
meta_graph = csv_ds.MetaGraph(file_name=csv_path)
graph_data = csv_ds.GraphData.load_from_csv(
meta_graph, csv_ds.DefaultDataParser())
meta_graph = MetaGraph(file_name=csv_path)
graph_data = GraphData.load_from_csv(
meta_graph, DefaultDataParser())
assert np.array_equal(df['graph_id'], graph_data.graph_id)
assert len(graph_data.data) == 2
assert np.array_equal(df['feat'], graph_data.data['feat'])
@@ -655,11 +662,11 @@ def _test_load_graph_data_from_csv():
df = pd.DataFrame({'label': np.random.randint(3, size=num_graphs)})
csv_path = os.path.join(test_dir, 'graph.csv')
df.to_csv(csv_path, index=False)
meta_graph = csv_ds.MetaGraph(file_name=csv_path)
meta_graph = MetaGraph(file_name=csv_path)
expect_except = False
try:
csv_ds.GraphData.load_from_csv(
meta_graph, csv_ds.DefaultDataParser())
GraphData.load_from_csv(
meta_graph, DefaultDataParser())
except DGLError:
expect_except = True
assert expect_except
@@ -716,7 +723,7 @@ def _test_DGLCSVDataset_single():
# remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0)
csv_dataset = csv_ds.DGLCSVDataset(
csv_dataset = data.DGLCSVDataset(
test_dir, force_reload=force_reload)
assert len(csv_dataset) == 1
g = csv_dataset[0]
@@ -799,7 +806,7 @@ def _test_DGLCSVDataset_multiple():
# remove original node data file to verify reload from cached files
os.remove(nodes_csv_path_0)
assert not os.path.exists(nodes_csv_path_0)
csv_dataset = csv_ds.DGLCSVDataset(
csv_dataset = data.DGLCSVDataset(
test_dir, force_reload=force_reload)
assert len(csv_dataset) == num_graphs
assert csv_dataset.has_cache()
@@ -885,7 +892,7 @@ def _test_DGLCSVDataset_customized_data_parser():
data[header] = dt
return data
# load CSVDataset with customized node/edge/graph_data_parser
csv_dataset = csv_ds.DGLCSVDataset(
csv_dataset = data.DGLCSVDataset(
test_dir, node_data_parser={'user': CustDataParser()}, edge_data_parser={('user', 'like', 'item'): CustDataParser()}, graph_data_parser=CustDataParser())
assert len(csv_dataset) == num_graphs
assert len(csv_dataset.data) == 1
@@ -906,10 +913,11 @@ def _test_DGLCSVDataset_customized_data_parser():
def _test_NodeEdgeGraphData():
from dgl.data.csv_dataset_base import NodeData, EdgeData, GraphData
# NodeData basics
num_nodes = 100
node_ids = np.arange(num_nodes, dtype=np.float64)
ndata = csv_ds.NodeData(node_ids, {})
ndata = NodeData(node_ids, {})
assert ndata.id.dtype == np.int64
assert np.array_equal(ndata.id, node_ids.astype(np.int64))
assert len(ndata.data) == 0
@@ -918,7 +926,7 @@ def _test_NodeEdgeGraphData():
# NodeData more
data = {'feat': np.random.rand(num_nodes, 3)}
graph_id = np.arange(num_nodes)
ndata = csv_ds.NodeData(node_ids, data, type='user', graph_id=graph_id)
ndata = NodeData(node_ids, data, type='user', graph_id=graph_id)
assert ndata.type == 'user'
assert np.array_equal(ndata.graph_id, graph_id)
assert len(ndata.data) == len(data)
@@ -928,7 +936,7 @@ def _test_NodeEdgeGraphData():
# NodeData except
expect_except = False
try:
csv_ds.NodeData(np.arange(num_nodes), {'feat': np.random.rand(
NodeData(np.arange(num_nodes), {'feat': np.random.rand(
num_nodes+1, 3)}, graph_id=np.arange(num_nodes-1))
except:
expect_except = True
@@ -939,7 +947,7 @@ def _test_NodeEdgeGraphData():
num_edges = 1000
src_ids = np.random.randint(num_nodes, size=num_edges)
dst_ids = np.random.randint(num_nodes, size=num_edges)
edata = csv_ds.EdgeData(src_ids, dst_ids, {})
edata = EdgeData(src_ids, dst_ids, {})
assert np.array_equal(edata.src, src_ids)
assert np.array_equal(edata.dst, dst_ids)
assert edata.type == ('_V', '_E', '_V')
@@ -951,7 +959,7 @@ def _test_NodeEdgeGraphData():
data = {'feat': np.random.rand(num_edges, 3)}
etype = ('user', 'like', 'item')
graph_ids = np.arange(num_edges)
edata = csv_ds.EdgeData(src_ids, dst_ids, data,
edata = EdgeData(src_ids, dst_ids, data,
type=etype, graph_id=graph_ids)
assert edata.src.dtype == np.int64
assert edata.dst.dtype == np.int64
@@ -966,7 +974,7 @@ def _test_NodeEdgeGraphData():
# EdgeData except
expect_except = False
try:
csv_ds.EdgeData(np.arange(num_edges), np.arange(
EdgeData(np.arange(num_edges), np.arange(
num_edges+1), {'feat': np.random.rand(num_edges-1, 3)}, graph_id=np.arange(num_edges+2))
except:
expect_except = True
@@ -975,13 +983,13 @@ def _test_NodeEdgeGraphData():
# GraphData basics
num_graphs = 10
graph_ids = np.arange(num_graphs)
gdata = csv_ds.GraphData(graph_ids, {})
gdata = GraphData(graph_ids, {})
assert np.array_equal(gdata.graph_id, graph_ids)
assert len(gdata.data) == 0
# GraphData more
graph_ids = np.arange(num_graphs).astype(np.float64)
data = {'feat': np.random.rand(num_graphs, 3)}
gdata = csv_ds.GraphData(graph_ids, data)
gdata = GraphData(graph_ids, data)
assert gdata.graph_id.dtype == np.int64
assert np.array_equal(gdata.graph_id, graph_ids)
assert len(gdata.data) == len(data)
......