Unverified Commit b347590a authored by xiang song(charlie.song), committed by GitHub

[Dataset] Citation graph (#1902)



* citation graph

* GCN example uses new citation dataset

* mxnet gat

* trigger

* Fix

* Fix gat

* fix

* Fix tensorflow dgi

* Fix appnp, graphsage for mxnet

* fix monet and sgc for mxnet

* Fix tagcn

* update sgc, appnp
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 4be4b134
@@ -17,6 +17,7 @@
 from .bitcoinotc import BitcoinOTC
 from .gdelt import GDELT
 from .icews18 import ICEWS18
 from .qm7b import QM7b
+from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset

 def register_data_args(parser):
@@ -27,7 +28,6 @@ def register_data_args(parser):
         help=
         "The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit"
     )
-    citegrh.register_args(parser)

 def load_data(args):
@@ -37,8 +37,6 @@ def load_data(args):
         return citegrh.load_citeseer()
     elif args.dataset == 'pubmed':
         return citegrh.load_pubmed()
-    elif args.dataset == 'syn':
-        return citegrh.load_synthetic(args)
     elif args.dataset is not None and args.dataset.startswith('reddit'):
         return RedditDataset(self_loop=('self-loop' in args.dataset))
     else:
...
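Note: the practical effect of this hunk is that the three new dataset classes become importable from `dgl.data`. A minimal usage sketch (assuming a DGL build that includes this commit, i.e. dgl >= 0.5):

# Minimal sketch of the new class-based citation datasets.
from dgl.data import CoraGraphDataset

dataset = CoraGraphDataset()        # downloads and caches on first use
g = dataset[0]                      # the single citation graph
feat = g.ndata['feat']              # row-normalized node features
label = g.ndata['label']            # per-node class labels
train_mask = g.ndata['train_mask']  # boolean mask of training nodes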
@@ -11,18 +11,17 @@
 import networkx as nx
 import scipy.sparse as sp
 import os, sys

-from .utils import download, extract_archive, get_download_dir, _get_dgl_url
-from ..utils import retry_method_with_fix
+from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
+from .utils import generate_mask_tensor
+from .utils import deprecate_property, deprecate_function
+from .dgl_dataset import DGLBuiltinDataset
 from .. import convert
 from .. import batch
 from .. import backend as F
+from ..convert import graph as dgl_graph
+from ..convert import to_networkx

-_urls = {
-    'cora_v2' : 'dataset/cora_v2.zip',
-    'citeseer' : 'dataset/citeseer.zip',
-    'pubmed' : 'dataset/pubmed.zip',
-    'cora_binary' : 'dataset/cora_binary.zip',
-}
+backend = os.environ.get('DGLBACKEND', 'pytorch')
 def _pickle_load(pkl_file):
     if sys.version_info > (3, 0):
@@ -30,7 +29,7 @@ def _pickle_load(pkl_file):
     else:
         return pkl.load(pkl_file)
-class CitationGraphDataset(object):
+class CitationGraphDataset(DGLBuiltinDataset):
     r"""The citation graph dataset, including cora, citeseer and pubmed.
     Nodes mean papers and edges mean citation relationships.
@@ -38,8 +37,21 @@ class CitationGraphDataset(object):
     ----------
     name: str
       name can be 'cora', 'citeseer' or 'pubmed'.
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
     """
-    def __init__(self, name):
+    _urls = {
+        'cora_v2' : 'dataset/cora_v2.zip',
+        'citeseer' : 'dataset/citeseer.zip',
+        'pubmed' : 'dataset/pubmed.zip',
+    }
+
+    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
         assert name.lower() in ['cora', 'citeseer', 'pubmed']
         # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
@@ -47,18 +59,15 @@ class CitationGraphDataset(object):
         if name.lower() == 'cora':
             name = 'cora_v2'
-        self.name = name
-        self.dir = get_download_dir()
-        self.zip_file_path='{}/{}.zip'.format(self.dir, name)
-        self._load()
-
-    def _download_and_extract(self):
-        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
-        extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
-
-    @retry_method_with_fix(_download_and_extract)
-    def _load(self):
-        """Loads input data from gcn/data directory
+        url = _get_dgl_url(self._urls[name])
+        super(CitationGraphDataset, self).__init__(name,
+                                                   url=url,
+                                                   raw_dir=raw_dir,
+                                                   force_reload=force_reload,
+                                                   verbose=verbose)
+
+    def process(self):
+        """Loads input data from the data directory

         ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
         ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
@@ -70,13 +79,8 @@ class CitationGraphDataset(object):
         ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict object;
         ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
-
-        All objects above must be saved using python pickle module.
-
-        :param name: Dataset name
-        :return: All data input files loaded (as well the training/test data).
         """
-        root = '{}/{}'.format(self.dir, self.name)
+        root = self.raw_path
         objnames = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
         objects = []
         for i in range(len(objnames)):
@@ -114,37 +118,135 @@ class CitationGraphDataset(object):
         val_mask = _sample_mask(idx_val, labels.shape[0])
         test_mask = _sample_mask(idx_test, labels.shape[0])

-        self.graph = graph
-        self.features = _preprocess_features(features)
-        self.labels = labels
-        self.onehot_labels = onehot_labels
-        self.num_labels = onehot_labels.shape[1]
-        self.train_mask = train_mask
-        self.val_mask = val_mask
-        self.test_mask = test_mask
+        self._graph = graph
+        g = dgl_graph(graph)
+
+        g.ndata['train_mask'] = generate_mask_tensor(train_mask)
+        g.ndata['val_mask'] = generate_mask_tensor(val_mask)
+        g.ndata['test_mask'] = generate_mask_tensor(test_mask)
+        g.ndata['label'] = F.tensor(labels)
+        g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
+        self._num_labels = onehot_labels.shape[1]
+        self._labels = labels
+        self._g = g

-        print('Finished data loading and preprocessing.')
-        print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
-        print(' NumEdges: {}'.format(self.graph.number_of_edges()))
-        print(' NumFeats: {}'.format(self.features.shape[1]))
-        print(' NumClasses: {}'.format(self.num_labels))
-        print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
-        print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
-        print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
+        if self.verbose:
+            print('Finished data loading and preprocessing.')
+            print(' NumNodes: {}'.format(self._g.number_of_nodes()))
+            print(' NumEdges: {}'.format(self._g.number_of_edges()))
+            print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
+            print(' NumClasses: {}'.format(self.num_labels))
+            print(' NumTrainingSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
+            print(' NumValidationSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
+            print(' NumTestSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))
+    def has_cache(self):
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        info_path = os.path.join(self.save_path,
+                                 self.save_name + '.pkl')
+        if os.path.exists(graph_path) and \
+            os.path.exists(info_path):
+            return True
+        return False
+
+    def save(self):
+        """Save the graph and the label info."""
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        info_path = os.path.join(self.save_path,
+                                 self.save_name + '.pkl')
+        save_graphs(str(graph_path), self._g)
+        save_info(str(info_path), {'num_labels': self.num_labels})
+
+    def load(self):
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        info_path = os.path.join(self.save_path,
+                                 self.save_name + '.pkl')
+        graphs, _ = load_graphs(str(graph_path))
+        info = load_info(str(info_path))
+        self._g = graphs[0]
+
+        # rebuild the deprecated networkx view for backward compatibility
+        graph = self._g.clone()
+        graph.ndata.pop('train_mask')
+        graph.ndata.pop('val_mask')
+        graph.ndata.pop('test_mask')
+        graph.ndata.pop('feat')
+        graph.ndata.pop('label')
+        graph = to_networkx(graph)
+        self._graph = nx.DiGraph(graph)
+
+        self._num_labels = info['num_labels']
+        # hack for mxnet compatibility: regenerate the masks from numpy
+        self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
+        self._g.ndata['val_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['val_mask']))
+        self._g.ndata['test_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['test_mask']))
+
+        if self.verbose:
+            print(' NumNodes: {}'.format(self._g.number_of_nodes()))
+            print(' NumEdges: {}'.format(self._g.number_of_edges()))
+            print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
+            print(' NumClasses: {}'.format(self.num_labels))
+            print(' NumTrainingSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
+            print(' NumValidationSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
+            print(' NumTestSamples: {}'.format(
+                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))
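Note: `has_cache`, `process`, `save`, and `load` are hooks consumed by the `DGLBuiltinDataset` base class. A sketch of the control flow they plug into (illustrative only; the real logic lives in dgl/data/dgl_dataset.py, which is not part of this diff):

# Illustrative only: how the base class is expected to drive the hooks above.
def _load(dataset, force_reload=False):
    if not force_reload and dataset.has_cache():
        dataset.load()       # fast path: read the cached .bin graph and .pkl info
    else:
        dataset.download()   # fetch and extract the zip pointed to by dataset.url
        dataset.process()    # parse the raw ind.* pickle files into a DGLGraph
        dataset.save()       # write the cache via save_graphs()/save_info()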
     def __getitem__(self, idx):
         assert idx == 0, "This dataset has only one graph"
-        g = convert.graph(self.graph)
-        g.ndata['train_mask'] = F.tensor(self.train_mask, F.bool)
-        g.ndata['val_mask'] = F.tensor(self.val_mask, F.bool)
-        g.ndata['test_mask'] = F.tensor(self.test_mask, F.bool)
-        g.ndata['label'] = F.tensor(self.labels, F.int64)
-        g.ndata['feat'] = F.tensor(self.features, F.float32)
-        return g
+        return self._g

     def __len__(self):
         return 1
+    @property
+    def save_name(self):
+        return self.name + '_dgl_graph'
+
+    @property
+    def num_labels(self):
+        return self._num_labels
+
+    """ Citation graph is used in many examples.
+        We preserve these properties for compatibility.
+    """
+    @property
+    def graph(self):
+        deprecate_property('dataset.graph', 'dataset.g')
+        return self._graph
+
+    @property
+    def train_mask(self):
+        deprecate_property('dataset.train_mask', 'g.ndata[\'train_mask\']')
+        return F.asnumpy(self._g.ndata['train_mask'])
+
+    @property
+    def val_mask(self):
+        deprecate_property('dataset.val_mask', 'g.ndata[\'val_mask\']')
+        return F.asnumpy(self._g.ndata['val_mask'])
+
+    @property
+    def test_mask(self):
+        deprecate_property('dataset.test_mask', 'g.ndata[\'test_mask\']')
+        return F.asnumpy(self._g.ndata['test_mask'])
+
+    @property
+    def labels(self):
+        deprecate_property('dataset.labels', 'g.ndata[\'label\']')
+        return F.asnumpy(self._g.ndata['label'])
+
+    @property
+    def features(self):
+        deprecate_property('dataset.features', 'g.ndata[\'feat\']')
+        return self._g.ndata['feat']
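Note: these shims keep pre-0.5 example code running while steering users to the graph-centric API. Side by side (a usage sketch, not from the diff):

dataset = CoraGraphDataset()

# Deprecated path: property access triggers a deprecate_property() warning.
feat = dataset.features
mask = dataset.train_mask

# Preferred path: everything hangs off the DGLGraph itself.
g = dataset[0]
feat = g.ndata['feat']
mask = g.ndata['train_mask']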
 def _preprocess_features(features):
     """Row-normalize feature matrix and convert to tuple representation"""
     rowsum = np.asarray(features.sum(1))
@@ -167,139 +269,436 @@ def _sample_mask(idx, l):
     mask[idx] = 1
     return mask
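Note: the helper's body is truncated by the hunk, but row normalization here is the standard GCN preprocessing: divide each feature row by its sum, leaving all-zero rows (e.g. citeseer's isolated nodes) untouched. A dense-numpy sketch of the idea (not the exact helper, which operates on a scipy sparse matrix):

import numpy as np

def row_normalize(feat):
    # Scale each row by the inverse of its sum; all-zero rows stay zero.
    feat = np.asarray(feat, dtype=np.float64)
    rowsum = feat.sum(axis=1, keepdims=True)
    inv = np.divide(1.0, rowsum, out=np.zeros_like(rowsum), where=rowsum != 0)
    return feat * inv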
-def load_cora():
-    data = CitationGraphDataset('cora')
-    return data
-
-def load_citeseer():
-    data = CitationGraphDataset('citeseer')
-    return data
-
-def load_pubmed():
-    data = CitationGraphDataset('pubmed')
-    return data
-
-class GCNSyntheticDataset(object):
-    def __init__(self,
-                 graph_generator,
-                 num_feats=500,
-                 num_classes=10,
-                 train_ratio=1.,
-                 val_ratio=0.,
-                 test_ratio=0.,
-                 seed=None):
-        rng = np.random.RandomState(seed)
-        # generate graph
-        self.graph = graph_generator(seed)
-        num_nodes = self.graph.number_of_nodes()
-        # generate features
-        #self.features = rng.randn(num_nodes, num_feats).astype(np.float32)
-        self.features = np.zeros((num_nodes, num_feats), dtype=np.float32)
-        # generate labels
-        self.labels = rng.randint(num_classes, size=num_nodes)
-        onehot_labels = np.zeros((num_nodes, num_classes), dtype=np.float32)
-        onehot_labels[np.arange(num_nodes), self.labels] = 1.
-        self.onehot_labels = onehot_labels
-        self.num_labels = num_classes
-        # generate masks
-        ntrain = int(num_nodes * train_ratio)
-        nval = int(num_nodes * val_ratio)
-        ntest = int(num_nodes * test_ratio)
-        mask_array = np.zeros((num_nodes,), dtype=np.int32)
-        mask_array[0:ntrain] = 1
-        mask_array[ntrain:ntrain+nval] = 2
-        mask_array[ntrain+nval:ntrain+nval+ntest] = 3
-        rng.shuffle(mask_array)
-        self.train_mask = (mask_array == 1).astype(np.int32)
-        self.val_mask = (mask_array == 2).astype(np.int32)
-        self.test_mask = (mask_array == 3).astype(np.int32)
-        print('Finished synthetic dataset generation.')
-        print(' NumNodes: {}'.format(self.graph.number_of_nodes()))
-        print(' NumEdges: {}'.format(self.graph.number_of_edges()))
-        print(' NumFeats: {}'.format(self.features.shape[1]))
-        print(' NumClasses: {}'.format(self.num_labels))
-        print(' NumTrainingSamples: {}'.format(len(np.nonzero(self.train_mask)[0])))
-        print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
-        print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
-
-    def __getitem__(self, idx):
-        return self
-
-    def __len__(self):
-        return 1
-
-def get_gnp_generator(args):
-    n = args.syn_gnp_n
-    p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
-    def _gen(seed):
-        return nx.fast_gnp_random_graph(n, p, seed, True)
-    return _gen
-
-class ScipyGraph(object):
-    """A simple graph object that uses scipy matrix."""
-    def __init__(self, mat):
-        self._mat = mat
-
-    def get_graph(self):
-        return self._mat
-
-    def number_of_nodes(self):
-        return self._mat.shape[0]
-
-    def number_of_edges(self):
-        return self._mat.getnnz()
-
-def get_scipy_generator(args):
-    n = args.syn_gnp_n
-    p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
-    def _gen(seed):
-        return ScipyGraph(sp.random(n, n, p, format='coo'))
-    return _gen
-
-def load_synthetic(args):
-    ty = args.syn_type
-    if ty == 'gnp':
-        gen = get_gnp_generator(args)
-    elif ty == 'scipy':
-        gen = get_scipy_generator(args)
-    else:
-        raise ValueError('Unknown graph generator type: {}'.format(ty))
-    return GCNSyntheticDataset(
-        gen,
-        args.syn_nfeats,
-        args.syn_nclasses,
-        args.syn_train_ratio,
-        args.syn_val_ratio,
-        args.syn_test_ratio,
-        args.syn_seed)
-
-def register_args(parser):
-    # Args for synthetic graphs.
-    parser.add_argument('--syn-type', type=str, default='gnp',
-                        help='Type of the synthetic graph generator')
-    parser.add_argument('--syn-nfeats', type=int, default=500,
-                        help='Number of node features')
-    parser.add_argument('--syn-nclasses', type=int, default=10,
-                        help='Number of output classes')
-    parser.add_argument('--syn-train-ratio', type=float, default=.1,
-                        help='Ratio of training nodes')
-    parser.add_argument('--syn-val-ratio', type=float, default=.2,
-                        help='Ratio of validation nodes')
-    parser.add_argument('--syn-test-ratio', type=float, default=.5,
-                        help='Ratio of testing nodes')
-    # Args for GNP generator
-    parser.add_argument('--syn-gnp-n', type=int, default=1000,
-                        help='n in gnp random graph')
-    parser.add_argument('--syn-gnp-p', type=float, default=0.0,
-                        help='p in gnp random graph')
-    parser.add_argument('--syn-seed', type=int, default=42,
-                        help='random seed')
+class CoraGraphDataset(CitationGraphDataset):
+    r""" Cora citation network dataset.
+
+    .. deprecated:: 0.5.0
+        `graph` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+        `train_mask` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+            >>> train_mask = graph.ndata['train_mask']
+        `val_mask` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+            >>> val_mask = graph.ndata['val_mask']
+        `test_mask` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+            >>> test_mask = graph.ndata['test_mask']
+        `labels` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+            >>> labels = graph.ndata['label']
+        `feat` is deprecated, it is replaced by:
+            >>> dataset = CoraGraphDataset()
+            >>> graph = dataset[0]
+            >>> feat = graph.ndata['feat']
+
+    Nodes mean papers and edges mean citation
+    relationships. Each node has a predefined
+    feature with 1433 dimensions. The dataset is
+    designed for the node classification task.
+    The task is to predict the category of
+    a certain paper.
+
+    Statistics
+    ----------
+    Nodes: 2708
+    Edges: 10556
+    Number of Classes: 7
+    Label Split: Train: 140, Valid: 500, Test: 1000
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Attributes
+    ----------
+    num_labels : int
+        Number of label classes
+    graph : networkx.DiGraph
+        Graph structure
+    train_mask : numpy array
+        Mask of training nodes
+    val_mask : numpy array
+        Mask of validation nodes
+    test_mask : numpy array
+        Mask of test nodes
+    labels : numpy array
+        Ground truth labels of each node
+    features : Tensor
+        Node features
+
+    Notes
+    -----
+    The node features are row-normalized.
+
+    Examples
+    --------
+    >>> dataset = CoraGraphDataset()
+    >>> g = dataset[0]
+    >>> num_class = dataset.num_labels
+    >>>
+    >>> # get node features
+    >>> feat = g.ndata['feat']
+    >>>
+    >>> # get the data split
+    >>> train_mask = g.ndata['train_mask']
+    >>> val_mask = g.ndata['val_mask']
+    >>> test_mask = g.ndata['test_mask']
+    >>>
+    >>> # get labels
+    >>> label = g.ndata['label']
+    >>>
+    >>> # Train, Validation and Test (see the sketch after CiteseerGraphDataset)
+    """
+    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
+        name = 'cora'
+
+        super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
+
+    def __getitem__(self, idx):
+        r"""Gets the graph object.
+
+        Parameters
+        ----------
+        idx : int
+            Item index; CoraGraphDataset has only one graph object.
+
+        Returns
+        -------
+        dgl.DGLGraph
+            Graph structure, node features and labels.
+            - ndata['train_mask']: mask for the training node set
+            - ndata['val_mask']: mask for the validation node set
+            - ndata['test_mask']: mask for the test node set
+            - ndata['feat']: node features
+            - ndata['label']: ground truth labels
+        """
+        return super(CoraGraphDataset, self).__getitem__(idx)
+
+    def __len__(self):
+        r"""The number of graphs in the dataset."""
+        return super(CoraGraphDataset, self).__len__()
+
+class CiteseerGraphDataset(CitationGraphDataset):
+    r""" Citeseer citation network dataset.
+
+    .. deprecated:: 0.5.0
+        `graph` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+        `train_mask` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+            >>> train_mask = graph.ndata['train_mask']
+        `val_mask` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+            >>> val_mask = graph.ndata['val_mask']
+        `test_mask` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+            >>> test_mask = graph.ndata['test_mask']
+        `labels` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+            >>> labels = graph.ndata['label']
+        `feat` is deprecated, it is replaced by:
+            >>> dataset = CiteseerGraphDataset()
+            >>> graph = dataset[0]
+            >>> feat = graph.ndata['feat']
+
+    Nodes mean scientific publications and edges
+    mean citation relationships. Each node has a
+    predefined feature with 3703 dimensions. The
+    dataset is designed for the node classification
+    task. The task is to predict the category of
+    a certain publication.
+
+    Statistics
+    ----------
+    Nodes: 3327
+    Edges: 9228
+    Number of Classes: 6
+    Label Split: Train: 120, Valid: 500, Test: 1000
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Attributes
+    ----------
+    num_labels : int
+        Number of label classes
+    graph : networkx.DiGraph
+        Graph structure
+    train_mask : numpy array
+        Mask of training nodes
+    val_mask : numpy array
+        Mask of validation nodes
+    test_mask : numpy array
+        Mask of test nodes
+    labels : numpy array
+        Ground truth labels of each node
+    features : Tensor
+        Node features
+
+    Notes
+    -----
+    The node features are row-normalized.
+
+    In the citeseer dataset, there are some isolated nodes in the graph.
+    These isolated nodes are added as zero vectors in the right positions.
+
+    Examples
+    --------
+    >>> dataset = CiteseerGraphDataset()
+    >>> g = dataset[0]
+    >>> num_class = dataset.num_labels
+    >>>
+    >>> # get node features
+    >>> feat = g.ndata['feat']
+    >>>
+    >>> # get the data split
+    >>> train_mask = g.ndata['train_mask']
+    >>> val_mask = g.ndata['val_mask']
+    >>> test_mask = g.ndata['test_mask']
+    >>>
+    >>> # get labels
+    >>> label = g.ndata['label']
+    >>>
+    >>> # Train, Validation and Test (see the sketch below)
+    """
+    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
+        name = 'citeseer'
+
+        super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
+
+    def __getitem__(self, idx):
+        r"""Gets the graph object.
+
+        Parameters
+        ----------
+        idx : int
+            Item index; CiteseerGraphDataset has only one graph object.
+
+        Returns
+        -------
+        dgl.DGLGraph
+            Graph structure, node features and labels.
+            - ndata['train_mask']: mask for the training node set
+            - ndata['val_mask']: mask for the validation node set
+            - ndata['test_mask']: mask for the test node set
+            - ndata['feat']: node features
+            - ndata['label']: ground truth labels
+        """
+        return super(CiteseerGraphDataset, self).__getitem__(idx)
+
+    def __len__(self):
+        r"""The number of graphs in the dataset."""
+        return super(CiteseerGraphDataset, self).__len__()
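Note: the docstring examples above stop at `# Train, Validation and Test`. For completeness, a hedged sketch of what that step typically looks like with these masks (PyTorch backend assumed; the single linear layer is a stand-in for any DGL node-classification model such as a GCN):

import torch
import torch.nn as nn
from dgl.data import CoraGraphDataset

dataset = CoraGraphDataset()
g = dataset[0]
feat, label = g.ndata['feat'], g.ndata['label']
train_mask, val_mask = g.ndata['train_mask'], g.ndata['val_mask']

# Stand-in model: a single linear layer on the node features.
model = nn.Linear(feat.shape[1], dataset.num_labels)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    logits = model(feat)
    loss = nn.functional.cross_entropy(logits[train_mask], label[train_mask])
    opt.zero_grad()
    loss.backward()
    opt.step()

with torch.no_grad():
    val_acc = (model(feat)[val_mask].argmax(1) == label[val_mask]).float().mean()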
+class PubmedGraphDataset(CitationGraphDataset):
+    r""" Pubmed citation network dataset.
+
+    .. deprecated:: 0.5.0
+        `graph` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+        `train_mask` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+            >>> train_mask = graph.ndata['train_mask']
+        `val_mask` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+            >>> val_mask = graph.ndata['val_mask']
+        `test_mask` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+            >>> test_mask = graph.ndata['test_mask']
+        `labels` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+            >>> labels = graph.ndata['label']
+        `feat` is deprecated, it is replaced by:
+            >>> dataset = PubmedGraphDataset()
+            >>> graph = dataset[0]
+            >>> feat = graph.ndata['feat']
+
+    Nodes mean scientific publications and edges
+    mean citation relationships. Each node has a
+    predefined feature with 500 dimensions. The
+    dataset is designed for the node classification
+    task. The task is to predict the category of
+    a certain publication.
+
+    Statistics
+    ----------
+    Nodes: 19717
+    Edges: 88651
+    Number of Classes: 3
+    Label Split: Train: 60, Valid: 500, Test: 1000
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Attributes
+    ----------
+    num_labels : int
+        Number of label classes
+    graph : networkx.DiGraph
+        Graph structure
+    train_mask : numpy array
+        Mask of training nodes
+    val_mask : numpy array
+        Mask of validation nodes
+    test_mask : numpy array
+        Mask of test nodes
+    labels : numpy array
+        Ground truth labels of each node
+    features : Tensor
+        Node features
+
+    Notes
+    -----
+    The node features are row-normalized.
+
+    Examples
+    --------
+    >>> dataset = PubmedGraphDataset()
+    >>> g = dataset[0]
+    >>> num_class = dataset.num_labels
+    >>>
+    >>> # get node features
+    >>> feat = g.ndata['feat']
+    >>>
+    >>> # get the data split
+    >>> train_mask = g.ndata['train_mask']
+    >>> val_mask = g.ndata['val_mask']
+    >>> test_mask = g.ndata['test_mask']
+    >>>
+    >>> # get labels
+    >>> label = g.ndata['label']
+    >>>
+    >>> # Train, Validation and Test (see the sketch above)
+    """
+    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
+        name = 'pubmed'
+
+        super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose)
+
+    def __getitem__(self, idx):
+        r"""Gets the graph object.
+
+        Parameters
+        ----------
+        idx : int
+            Item index; PubmedGraphDataset has only one graph object.
+
+        Returns
+        -------
+        dgl.DGLGraph
+            Graph structure, node features and labels.
+            - ndata['train_mask']: mask for the training node set
+            - ndata['val_mask']: mask for the validation node set
+            - ndata['test_mask']: mask for the test node set
+            - ndata['feat']: node features
+            - ndata['label']: ground truth labels
+        """
+        return super(PubmedGraphDataset, self).__getitem__(idx)
+
+    def __len__(self):
+        r"""The number of graphs in the dataset."""
+        return super(PubmedGraphDataset, self).__len__()
+def load_cora(raw_dir=None, force_reload=False, verbose=True):
+    """Get CoraGraphDataset
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Returns
+    -------
+    CoraGraphDataset
+    """
+    data = CoraGraphDataset(raw_dir, force_reload, verbose)
+    return data
+
+def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
+    """Get CiteseerGraphDataset
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Returns
+    -------
+    CiteseerGraphDataset
+    """
+    data = CiteseerGraphDataset(raw_dir, force_reload, verbose)
+    return data
+
+def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
+    """Get PubmedGraphDataset
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Returns
+    -------
+    PubmedGraphDataset
+    """
+    data = PubmedGraphDataset(raw_dir, force_reload, verbose)
+    return data
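Note: the module-level loaders survive as thin wrappers, so existing scripts keep working. Equivalent calls (a sketch):

# The wrapper and the class construct the same dataset object.
data = load_cora()      # wraps CoraGraphDataset(raw_dir, force_reload, verbose)
same = CoraGraphDataset()
assert type(data) is type(same)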
-class CoraBinary(object):
+class CoraBinary(DGLBuiltinDataset):
     """A mini-dataset for binary classification task using Cora.

     After loaded, it has following members:
@@ -307,20 +706,28 @@ class CoraBinary(object):
     graphs : list of :class:`~dgl.DGLGraph`
     pmpds : list of :class:`scipy.sparse.coo_matrix`
     labels : list of :class:`numpy.ndarray`
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
     """
-    def __init__(self):
-        self.dir = get_download_dir()
-        self.name = 'cora_binary'
-        self.zip_file_path='{}/{}.zip'.format(self.dir, self.name)
-        self._load()
-
-    def _download_and_extract(self):
-        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
-        extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
-
-    @retry_method_with_fix(_download_and_extract)
-    def _load(self):
-        root = '{}/{}'.format(self.dir, self.name)
+    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
+        name = 'cora_binary'
+        url = _get_dgl_url('dataset/cora_binary.zip')
+        super(CoraBinary, self).__init__(name,
+                                         url=url,
+                                         raw_dir=raw_dir,
+                                         force_reload=force_reload,
+                                         verbose=verbose)
+
+    def process(self):
+        root = self.raw_path
         # load graphs
         self.graphs = []
         with open("{}/graphs.txt".format(root), 'r') as f:
@@ -328,13 +735,13 @@ class CoraBinary(object):
             for line in f.readlines():
                 if line.startswith('graph'):
                     if len(elist) != 0:
-                        self.graphs.append(convert.graph(elist))
+                        self.graphs.append(dgl_graph(elist))
                     elist = []
                 else:
                     u, v = line.strip().split(' ')
                     elist.append((int(u), int(v)))
             if len(elist) != 0:
-                self.graphs.append(convert.graph(elist))
+                self.graphs.append(dgl_graph(elist))
         with open("{}/pmpds.pkl".format(root), 'rb') as f:
             self.pmpds = _pickle_load(f)
         self.labels = []
@@ -353,12 +760,64 @@ class CoraBinary(object):
         assert len(self.graphs) == len(self.pmpds)
         assert len(self.graphs) == len(self.labels)
+    def has_cache(self):
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        if os.path.exists(graph_path):
+            return True
+        return False
+
+    def save(self):
+        """Save the graph list and the labels."""
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        labels = {}
+        for i, label in enumerate(self.labels):
+            labels['{}'.format(i)] = F.tensor(label)
+        save_graphs(str(graph_path), self.graphs, labels)
+        if self.verbose:
+            print('Done saving data into cached files.')
+
+    def load(self):
+        graph_path = os.path.join(self.save_path,
+                                  self.save_name + '.bin')
+        self.graphs, labels = load_graphs(str(graph_path))
+
+        self.labels = []
+        for i in range(len(labels)):
+            self.labels.append(F.asnumpy(labels['{}'.format(i)]))
+        # load pmpds under self.raw_path
+        with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f:
+            self.pmpds = _pickle_load(f)
+        if self.verbose:
+            print('Done loading data from cached files.')
+        # sanity check
+        assert len(self.graphs) == len(self.pmpds)
+        assert len(self.graphs) == len(self.labels)
     def __len__(self):
         return len(self.graphs)

     def __getitem__(self, i):
+        r"""Gets the i-th sample.
+
+        Parameters
+        ----------
+        i : int
+            The sample index.
+
+        Returns
+        -------
+        (dgl.DGLGraph, scipy.sparse.coo_matrix, int)
+            The graph, the scipy sparse coo_matrix, and its label.
+        """
         return (self.graphs[i], self.pmpds[i], self.labels[i])

+    @property
+    def save_name(self):
+        return self.name + '_dgl_graph'
+
     @staticmethod
     def collate_fn(cur):
         graphs, pmpds, labels = zip(*cur)
...
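Note: `collate_fn` exists so `CoraBinary` can feed a standard mini-batch loader. A sketch of the intended use (assuming the PyTorch `DataLoader`; the collate body is truncated above, but per the `zip(*cur)` line it regroups the per-sample triples into batch form):

from torch.utils.data import DataLoader
from dgl.data import CoraBinary

dataset = CoraBinary()
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=CoraBinary.collate_fn)
for batch in loader:
    # Exact structure depends on the truncated collate_fn body;
    # each sample contributes a (graph, pmpd, label) triple.
    graphs, pmpds, labels = batch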
@@ -249,6 +249,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
         # save for compatibility
         self._train_idx = F.tensor(train_idx)
         self._test_idx = F.tensor(test_idx)
+        self._labels = labels

     def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
         """Build the graphs
@@ -638,17 +639,17 @@ class AIFBDataset(RDFGraphDataset):
         Returns
         -------
         dgl.DGLGraph
             Graph structure, node features and labels.
             - ndata['train_mask']: mask for the training node set
             - ndata['test_mask']: mask for the testing node set
             - ndata['labels']: node labels
         """
         return super(AIFBDataset, self).__getitem__(idx)

     def __len__(self):
         r"""The number of graphs in the dataset."""
-        return super(AIFBDataset, self).__len__(idx)
+        return super(AIFBDataset, self).__len__()

     def parse_entity(self, term):
         if isinstance(term, rdf.Literal):
@@ -801,17 +802,17 @@ class MUTAGDataset(RDFGraphDataset):
         Returns
         -------
         dgl.DGLGraph
             Graph structure, node features and labels.
             - ndata['train_mask']: mask for the training node set
             - ndata['test_mask']: mask for the testing node set
             - ndata['labels']: node labels
         """
         return super(MUTAGDataset, self).__getitem__(idx)

     def __len__(self):
         r"""The number of graphs in the dataset."""
-        return super(MUTAGDataset, self).__len__(idx)
+        return super(MUTAGDataset, self).__len__()

     def parse_entity(self, term):
         if isinstance(term, rdf.Literal):
@@ -980,17 +981,17 @@ class BGSDataset(RDFGraphDataset):
         Returns
         -------
         dgl.DGLGraph
             Graph structure, node features and labels.
             - ndata['train_mask']: mask for the training node set
             - ndata['test_mask']: mask for the testing node set
             - ndata['labels']: node labels
         """
         return super(BGSDataset, self).__getitem__(idx)

     def __len__(self):
         r"""The number of graphs in the dataset."""
-        return super(BGSDataset, self).__len__(idx)
+        return super(BGSDataset, self).__len__()

     def parse_entity(self, term):
         if isinstance(term, rdf.Literal):
@@ -1155,17 +1156,17 @@ class AMDataset(RDFGraphDataset):
         Returns
         -------
         dgl.DGLGraph
             Graph structure, node features and labels.
             - ndata['train_mask']: mask for the training node set
             - ndata['test_mask']: mask for the testing node set
             - ndata['labels']: node labels
         """
         return super(AMDataset, self).__getitem__(idx)

     def __len__(self):
         r"""The number of graphs in the dataset."""
-        return super(AMDataset, self).__len__(idx)
+        return super(AMDataset, self).__len__()

     def parse_entity(self, term):
         if isinstance(term, rdf.Literal):
...