Unverified Commit b347590a authored by xiang song(charlie.song), committed by GitHub

[Dataset] Citation graph (#1902)



* citation graph

* GCN example uses new citation dataset

* mxnet gat

* trigger

* Fix

* Fix gat

* fix

* Fix tensorflow dgi

* Fix appnp, graphsage for mxnet

* fix monet and sgc for mxnet

* Fix tagcn

* update sgc, appnp
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 4be4b134
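This commit ports the citation datasets onto the DGLBuiltinDataset base class, so examples now read features, labels, and split masks from g.ndata instead of dataset attributes. A minimal sketch of the new-style usage implied by the diff below (illustrative only):

    from dgl.data import CoraGraphDataset

    dataset = CoraGraphDataset()        # downloads and caches under ~/.dgl/ by default
    g = dataset[0]                      # the single dgl.DGLGraph

    feat = g.ndata['feat']              # row-normalized node features
    label = g.ndata['label']            # ground-truth classes
    train_mask = g.ndata['train_mask']  # boolean masks replace dataset.train_mask etc.
    num_classes = dataset.num_labels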
@@ -17,6 +17,7 @@ from .bitcoinotc import BitcoinOTC
from .gdelt import GDELT
from .icews18 import ICEWS18
from .qm7b import QM7b
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
def register_data_args(parser):
@@ -27,7 +28,6 @@ def register_data_args(parser):
help=
"The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit"
)
citegrh.register_args(parser)
def load_data(args):
@@ -37,8 +37,6 @@ def load_data(args):
return citegrh.load_citeseer()
elif args.dataset == 'pubmed':
return citegrh.load_pubmed()
elif args.dataset == 'syn':
return citegrh.load_synthetic(args)
elif args.dataset is not None and args.dataset.startswith('reddit'):
return RedditDataset(self_loop=('self-loop' in args.dataset))
else:
......
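The __init__.py hunk above drops the synthetic-dataset hooks and exposes the new citation classes. A hedged sketch of the resulting CLI flow (argparse wiring as implied by the hunk; the cora branch of load_data is elided above but assumed):

    import argparse
    from dgl.data import register_data_args, load_data

    parser = argparse.ArgumentParser()
    register_data_args(parser)
    args = parser.parse_args(['--dataset', 'cora'])
    data = load_data(args)  # a CoraGraphDataset after this commit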
@@ -11,18 +11,17 @@ import networkx as nx
import scipy.sparse as sp
import os, sys
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
from .utils import generate_mask_tensor
from .utils import deprecate_property, deprecate_function
from .dgl_dataset import DGLBuiltinDataset
from .. import convert
from .. import batch
from .. import backend as F
from ..convert import graph as dgl_graph
from ..convert import to_networkx
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
'cora_binary' : 'dataset/cora_binary.zip',
}
backend = os.environ.get('DGLBACKEND', 'pytorch')
def _pickle_load(pkl_file):
if sys.version_info > (3, 0):
@@ -30,7 +29,7 @@ def _pickle_load(pkl_file):
else:
return pkl.load(pkl_file)
class CitationGraphDataset(object):
class CitationGraphDataset(DGLBuiltinDataset):
r"""The citation graph dataset, including cora, citeseer and pubmeb.
Nodes mean authors and edges mean citation relationships.
@@ -38,8 +37,21 @@ class CitationGraphDataset(object):
-----------
name: str
name can be 'cora', 'citeseer' or 'pubmed'.
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
"""
def __init__(self, name):
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
@@ -47,18 +59,15 @@ class CitationGraphDataset(object):
if name.lower() == 'cora':
name = 'cora_v2'
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
self._load()
def _download_and_extract(self):
download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
url = _get_dgl_url(self._urls[name])
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
@retry_method_with_fix(_download_and_extract)
def _load(self):
"""Loads input data from gcn/data directory
def process(self):
"""Loads input data from data directory
ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
@@ -70,13 +79,8 @@ class CitationGraphDataset(object):
ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
object;
ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
All objects above must be saved using the python pickle module.
:param name: Dataset name
:return: All data input files loaded (as well as the training/test data).
"""
root = '{}/{}'.format(self.dir, self.name)
root = self.raw_path
objnames = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
objects = []
for i in range(len(objnames)):
@@ -114,37 +118,135 @@ class CitationGraphDataset(object):
val_mask = _sample_mask(idx_val, labels.shape[0])
test_mask = _sample_mask(idx_test, labels.shape[0])
self.graph = graph
self.features = _preprocess_features(features)
self.labels = labels
self.onehot_labels = onehot_labels
self.num_labels = onehot_labels.shape[1]
self.train_mask = train_mask
self.val_mask = val_mask
self.test_mask = test_mask
print('Finished data loading and preprocessing.')
print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
print(' NumEdges: {}'.format(self.graph.number_of_edges()))
print(' NumFeats: {}'.format(self.features.shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
self._graph = graph
g = dgl_graph(graph)
g.ndata['train_mask'] = generate_mask_tensor(train_mask)
g.ndata['val_mask'] = generate_mask_tensor(val_mask)
g.ndata['test_mask'] = generate_mask_tensor(test_mask)
g.ndata['label'] = F.tensor(labels)
g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._labels = labels
self._g = g
if self.verbose:
print('Finished data loading and preprocessing.')
print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format(
F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
print(' NumTestSamples: {}'.format(
F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))
def has_cache(self):
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
if os.path.exists(graph_path) and \
os.path.exists(info_path):
return True
return False
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
save_graphs(str(graph_path), self._g)
save_info(str(info_path), {'num_labels': self.num_labels})
def load(self):
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path))
self._g = graphs[0]
graph = self._g.clone()
graph.ndata.pop('train_mask')
graph.ndata.pop('val_mask')
graph.ndata.pop('test_mask')
graph.ndata.pop('feat')
graph.ndata.pop('label')
graph = to_networkx(graph)
self._graph = nx.DiGraph(graph)
self._num_labels = info['num_labels']
self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
self._g.ndata['val_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['val_mask']))
self._g.ndata['test_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['test_mask']))
# hack for MXNet compatibility: MXNet has no boolean tensor type
if self.verbose:
print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format(
F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
print(' NumTestSamples: {}'.format(
F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
g = convert.graph(self.graph)
g.ndata['train_mask'] = F.tensor(self.train_mask, F.bool)
g.ndata['val_mask'] = F.tensor(self.val_mask, F.bool)
g.ndata['test_mask'] = F.tensor(self.test_mask, F.bool)
g.ndata['label'] = F.tensor(self.labels, F.int64)
g.ndata['feat'] = F.tensor(self.features, F.float32)
return g
return self._g
def __len__(self):
return 1
@property
def save_name(self):
return self.name + '_dgl_graph'
@property
def num_labels(self):
return self._num_labels
""" Citation graph is used in many examples
We preserve these properties for compatability.
"""
@property
def graph(self):
deprecate_property('dataset.graph', 'dataset.g')
return self._graph
@property
def train_mask(self):
deprecate_property('dataset.train_mask', 'g.ndata[\'train_mask\']')
return F.asnumpy(self._g.ndata['train_mask'])
@property
def val_mask(self):
deprecate_property('dataset.val_mask', 'g.ndata[\'val_mask\']')
return F.asnumpy(self._g.ndata['val_mask'])
@property
def test_mask(self):
deprecate_property('dataset.test_mask', 'g.ndata[\'test_mask\']')
return F.asnumpy(self._g.ndata['test_mask'])
@property
def labels(self):
deprecate_property('dataset.labels', 'g.ndata[\'label\']')
return F.asnumpy(self._g.ndata['label'])
@property
def features(self):
deprecate_property('dataset.features', 'g.ndata[\'feat\']')
return self._g.ndata['feat']
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.asarray(features.sum(1))
@@ -167,139 +269,436 @@ def _sample_mask(idx, l):
mask[idx] = 1
return mask
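_preprocess_features and _sample_mask are truncated in this hunk. For reference, a plain NumPy/SciPy sketch of the row normalization and mask construction they implement (a sketch of the technique, not the exact bodies):

    import numpy as np
    import scipy.sparse as sp

    def row_normalize(features):
        # Divide each row by its sum; rows summing to zero stay zero.
        rowsum = np.asarray(features.sum(1)).flatten()
        with np.errstate(divide='ignore'):
            r_inv = np.power(rowsum, -1.0)
        r_inv[np.isinf(r_inv)] = 0.
        return sp.diags(r_inv).dot(features)

    def sample_mask(idx, n):
        # Boolean mask of length n, True at the given indices.
        mask = np.zeros(n, dtype=bool)
        mask[idx] = True
        return mask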
def load_cora():
data = CitationGraphDataset('cora')
return data
class CoraGraphDataset(CitationGraphDataset):
r""" Cora citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
Nodes represent papers and edges represent citation
relationships. Each node carries a predefined
1433-dimensional feature vector. The dataset is
designed for the node classification task:
predicting the category of each paper.
Statistics
----------
Nodes: 2708
Edges: 10556
Number of Classes: 7
Label split: Train: 140, Valid: 500, Test: 1000
def load_citeseer():
data = CitationGraphDataset('citeseer')
return data
Parameters
----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels: int
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
Mask of training nodes
val_mask: Numpy array
Mask of validation nodes
test_mask: Numpy array
Mask of test nodes
labels: Numpy array
Ground truth labels of each node
features: Tensor
Node features
Notes
-----
The node feature is row-normalized.
Examples
--------
>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_labels
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
def load_pubmed():
data = CitationGraphDataset('pubmed')
return data
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
name = 'cora'
class GCNSyntheticDataset(object):
def __init__(self,
graph_generator,
num_feats=500,
num_classes=10,
train_ratio=1.,
val_ratio=0.,
test_ratio=0.,
seed=None):
rng = np.random.RandomState(seed)
# generate graph
self.graph = graph_generator(seed)
num_nodes = self.graph.number_of_nodes()
# generate features
#self.features = rng.randn(num_nodes, num_feats).astype(np.float32)
self.features = np.zeros((num_nodes, num_feats), dtype=np.float32)
# generate labels
self.labels = rng.randint(num_classes, size=num_nodes)
onehot_labels = np.zeros((num_nodes, num_classes), dtype=np.float32)
onehot_labels[np.arange(num_nodes), self.labels] = 1.
self.onehot_labels = onehot_labels
self.num_labels = num_classes
# generate masks
ntrain = int(num_nodes * train_ratio)
nval = int(num_nodes * val_ratio)
ntest = int(num_nodes * test_ratio)
mask_array = np.zeros((num_nodes,), dtype=np.int32)
mask_array[0:ntrain] = 1
mask_array[ntrain:ntrain+nval] = 2
mask_array[ntrain+nval:ntrain+nval+ntest] = 3
rng.shuffle(mask_array)
self.train_mask = (mask_array == 1).astype(np.int32)
self.val_mask = (mask_array == 2).astype(np.int32)
self.test_mask = (mask_array == 3).astype(np.int32)
print('Finished synthetic dataset generation.')
print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
print(' NumEdges: {}'.format(self.graph.number_of_edges()))
print(' NumFeats: {}'.format(self.features.shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
def __getitem__(self, idx):
return self
r"""Gets the graph object
Parameters
-----------
idx: int
Item index, CoraGraphDataset has only one graph object
Return
------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
"""
return super(CoraGraphDataset, self).__getitem__(idx)
def __len__(self):
return 1
r"""The number of graphs in the dataset."""
return super(CoraGraphDataset, self).__len__()
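The Examples block in the docstring above stops at the '# Train, Validation and Test' placeholder. A minimal hedged training sketch on the graph returned by dataset[0], assuming the PyTorch backend; the linear classifier is a stand-in for a real GNN and is not part of this diff:

    import torch
    import torch.nn.functional as Fn
    from dgl.data import CoraGraphDataset

    dataset = CoraGraphDataset()
    g = dataset[0]
    feat, label = g.ndata['feat'], g.ndata['label']
    train_mask, val_mask = g.ndata['train_mask'], g.ndata['val_mask']

    # A plain linear classifier stands in for a real GNN model here.
    model = torch.nn.Linear(feat.shape[1], dataset.num_labels)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(50):
        logits = model(feat)
        loss = Fn.cross_entropy(logits[train_mask], label[train_mask])
        opt.zero_grad()
        loss.backward()
        opt.step()
        val_acc = (logits[val_mask].argmax(1) == label[val_mask]).float().mean()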
class CiteseerGraphDataset(CitationGraphDataset):
r""" Citeseer citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
Nodes represent scientific publications and edges
represent citation relationships. Each node carries
a predefined 3703-dimensional feature vector. The
dataset is designed for the node classification
task: predicting the category of each publication.
Statistics
----------
Nodes: 3327
Edges: 9228
Number of Classes: 6
Label split: Train: 120, Valid: 500, Test: 1000
def get_gnp_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
def _gen(seed):
return nx.fast_gnp_random_graph(n, p, seed, True)
return _gen
class ScipyGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return self._mat.getnnz()
def get_scipy_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
def _gen(seed):
return ScipyGraph(sp.random(n, n, p, format='coo'))
return _gen
def load_synthetic(args):
ty = args.syn_type
if ty == 'gnp':
gen = get_gnp_generator(args)
elif ty == 'scipy':
gen = get_scipy_generator(args)
else:
raise ValueError('Unknown graph generator type: {}'.format(ty))
return GCNSyntheticDataset(
gen,
args.syn_nfeats,
args.syn_nclasses,
args.syn_train_ratio,
args.syn_val_ratio,
args.syn_test_ratio,
args.syn_seed)
def register_args(parser):
# Args for synthetic graphs.
parser.add_argument('--syn-type', type=str, default='gnp',
help='Type of the synthetic graph generator')
parser.add_argument('--syn-nfeats', type=int, default=500,
help='Number of node features')
parser.add_argument('--syn-nclasses', type=int, default=10,
help='Number of output classes')
parser.add_argument('--syn-train-ratio', type=float, default=.1,
help='Ratio of training nodes')
parser.add_argument('--syn-val-ratio', type=float, default=.2,
help='Ratio of validation nodes')
parser.add_argument('--syn-test-ratio', type=float, default=.5,
help='Ratio of testing nodes')
# Args for GNP generator
parser.add_argument('--syn-gnp-n', type=int, default=1000,
help='n in gnp random graph')
parser.add_argument('--syn-gnp-p', type=float, default=0.0,
help='p in gnp random graph')
parser.add_argument('--syn-seed', type=int, default=42,
help='random seed')
class CoraBinary(object):
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels: int
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
Mask of training nodes
val_mask: Numpy array
Mask of validation nodes
test_mask: Numpy array
Mask of test nodes
labels: Numpy array
Ground truth labels of each node
features: Tensor
Node features
Notes
-----
The node feature is row-normalized.
The citeseer dataset contains some isolated nodes; their
features are added as zero vectors at the corresponding positions.
Examples
--------
>>> dataset = CiteseerGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_labels
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
def __getitem__(self, idx):
r"""Gets the graph object
Parameters
-----------
idx: int
Item index, CiteseerGraphDataset has only one graph object
Return
------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
"""
return super(CiteseerGraphDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(CiteseerGraphDataset, self).__len__()
class PubmedGraphDataset(CitationGraphDataset):
r""" Pubmed citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
Nodes represent scientific publications and edges
represent citation relationships. Each node carries
a predefined 500-dimensional feature vector. The
dataset is designed for the node classification
task: predicting the category of each publication.
Statistics
----------
Nodes: 19717
Edges: 88651
Number of Classes: 3
Label split: Train: 60, Valid: 500, Test: 1000
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels: int
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
Mask of training nodes
val_mask: Numpy array
Mask of validation nodes
test_mask: Numpy array
Mask of test nodes
labels: Numpy array
Ground truth labels of each node
features: Tensor
Node features
Notes
-----
The node feature is row-normalized.
Examples
--------
>>> dataset = PubmedGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_labels
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
def __getitem__(self, idx):
r"""Gets the graph object
Parameters
-----------
idx: int
Item index, PubmedGraphDataset has only one graph object
Return
------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
"""
return super(PubmedGraphDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True):
"""Get CoraGraphDataset
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Return
-------
CoraGraphDataset
"""
data = CoraGraphDataset(raw_dir, force_reload, verbose)
return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
"""Get CiteseerGraphDataset
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Return
-------
CiteseerGraphDataset
"""
data = CiteseerGraphDataset(raw_dir, force_reload, verbose)
return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
"""Get PubmedGraphDataset
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Return
-------
PubmedGraphDataset
"""
data = PubmedGraphDataset(raw_dir, force_reload, verbose)
return data
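The load_cora/load_citeseer/load_pubmed helpers keep the old functional entry points while returning the new class-based datasets. A hedged sketch of migrating attribute-style code to the graph-centric API:

    from dgl.data import citation_graph as citegrh

    data = citegrh.load_cora()  # now a CoraGraphDataset

    # Old attribute style: still works, but triggers deprecate_property warnings.
    feat_old, labels_old = data.features, data.labels

    # New graph-centric style (preferred).
    g = data[0]
    feat, labels = g.ndata['feat'], g.ndata['label']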
class CoraBinary(DGLBuiltinDataset):
"""A mini-dataset for binary classification task using Cora.
After loading, it has the following members:
@@ -307,20 +706,28 @@ class CoraBinary(object):
graphs : list of :class:`~dgl.DGLGraph`
pmpds : list of :class:`scipy.sparse.coo_matrix`
labels : list of :class:`numpy.ndarray`
Parameters
-----------
raw_dir : str
Directory to download and store the raw input data.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
"""
def __init__(self):
self.dir = get_download_dir()
self.name = 'cora_binary'
self.zip_file_path='{}/{}.zip'.format(self.dir, self.name)
self._load()
def _download_and_extract(self):
download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
@retry_method_with_fix(_download_and_extract)
def _load(self):
root = '{}/{}'.format(self.dir, self.name)
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
name = 'cora_binary'
url = _get_dgl_url('dataset/cora_binary.zip')
super(CoraBinary, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
root = self.raw_path
# load graphs
self.graphs = []
with open("{}/graphs.txt".format(root), 'r') as f:
@@ -328,13 +735,13 @@ class CoraBinary(object):
for line in f.readlines():
if line.startswith('graph'):
if len(elist) != 0:
self.graphs.append(convert.graph(elist))
self.graphs.append(dgl_graph(elist))
elist = []
else:
u, v = line.strip().split(' ')
elist.append((int(u), int(v)))
if len(elist) != 0:
self.graphs.append(convert.graph(elist))
self.graphs.append(dgl_graph(elist))
with open("{}/pmpds.pkl".format(root), 'rb') as f:
self.pmpds = _pickle_load(f)
self.labels = []
@@ -353,12 +760,64 @@ class CoraBinary(object):
assert len(self.graphs) == len(self.pmpds)
assert len(self.graphs) == len(self.labels)
def has_cache(self):
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
if os.path.exists(graph_path):
return True
return False
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
labels = {}
for i, label in enumerate(self.labels):
labels['{}'.format(i)] = F.tensor(label)
save_graphs(str(graph_path), self.graphs, labels)
if self.verbose:
print('Done saving data into cached files.')
def load(self):
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
self.graphs, labels = load_graphs(str(graph_path))
self.labels = []
for i in range(len(labels)):
self.labels.append(F.asnumpy(labels['{}'.format(i)]))
# load pmpds under self.raw_path
with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f:
self.pmpds = _pickle_load(f)
if self.verbose:
print('Done loading data from cached files.')
# sanity check
assert len(self.graphs) == len(self.pmpds)
assert len(self.graphs) == len(self.labels)
def __len__(self):
return len(self.graphs)
def __getitem__(self, i):
r"""Gets the idx-th sample.
Parameters
-----------
idx : int
The sample index.
Returns
-------
(dgl.DGLGraph, scipy.sparse.coo_matrix, int)
The graph, scipy sparse coo_matrix and its label.
"""
return (self.graphs[i], self.pmpds[i], self.labels[i])
@property
def save_name(self):
return self.name + '_dgl_graph'
@staticmethod
def collate_fn(cur):
graphs, pmpds, labels = zip(*cur)
......
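collate_fn is truncated above. Since the pmpds are scipy sparse matrices that the default PyTorch collate cannot stack, a custom collate is needed; a hedged sketch of wiring it into a DataLoader, assuming collate_fn batches the graphs (e.g. via the dgl.batch imported at the top of the file) and zips the pmpds and labels:

    from torch.utils.data import DataLoader
    from dgl.data import CoraBinary

    dataset = CoraBinary()
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=CoraBinary.collate_fn)
    for graphs, pmpds, labels in loader:
        pass  # per-batch training goes here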
@@ -249,6 +249,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
# save for compatibility
self._train_idx = F.tensor(train_idx)
self._test_idx = F.tensor(test_idx)
self._labels = labels
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
"""Build the graphs
@@ -638,17 +639,17 @@ class AIFBDataset(RDFGraphDataset):
Return
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
"""
return super(AIFBDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(AIFBDataset, self).__len__(idx)
return super(AIFBDataset, self).__len__()
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
@@ -801,17 +802,17 @@ class MUTAGDataset(RDFGraphDataset):
Return
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
"""
return super(MUTAGDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(MUTAGDataset, self).__len__(idx)
return super(MUTAGDataset, self).__len__()
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
@@ -980,17 +981,17 @@ class BGSDataset(RDFGraphDataset):
Return
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
"""
return super(BGSDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(BGSDataset, self).__len__(idx)
return super(BGSDataset, self).__len__()
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
@@ -1155,17 +1156,17 @@ class AMDataset(RDFGraphDataset):
Return
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: node labels
"""
return super(AMDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
return super(AMDataset, self).__len__(idx)
return super(AMDataset, self).__len__()
def parse_entity(self, term):
if isinstance(term, rdf.Literal):
......
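For context on the RDF hunks, which fix __len__ being called with a spurious idx: usage after the fix might look like this hedged sketch (dataset classes per the diff; masks and labels per the docstrings above):

    from dgl.data.rdf import AIFBDataset

    dataset = AIFBDataset()
    g = dataset[0]             # graph with train/test masks and labels attached
    num_graphs = len(dataset)  # __len__ now takes no argument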