Unverified commit 90d2118d, authored by Mufei Li, committed by GitHub
Browse files

[Dataset] Change Cora Split (#1583)

* Use new version of Cora

* Fix import
parent b8dffcd5
......@@ -49,13 +49,6 @@ Citation Network dataset
:members: __getitem__, __len__
Cora Citation Network dataset
```````````````````````````````````
.. autoclass:: CoraDataset
:members: __getitem__, __len__
CoraFull dataset
```````````````````````````````````
......
......@@ -8,7 +8,7 @@ Dataset (Temporary)
+================+==========================================================+===========+===============+===============+============================================+===========+========+
|BitcoinOTC |BitcoinOTC() | 136| 6005.00| 21209.98| |h |True |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Cora |CoraDataset() | 1| 2708.00| 10556.00|train_mask, val_mask, test_mask, label, feat| |False |
|Cora |CitationGraphDataset('cora') | 1| 2708.00| 10556.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Citeseer |CitationGraphDataset('citeseer') | 1| 3327.00| 9228.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
......
......@@ -10,11 +10,11 @@ from dgl.data.gdelt import GDELT
from dgl.data.icews18 import ICEWS18
from dgl.data.qm7b import QM7b
# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, CoraDataset, PPIDataset, RedditDataset, TUDataset
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
ds_list = {
"BitcoinOTC": "BitcoinOTC()",
"Cora": "CoraDataset()",
"Cora": "CitationGraphDataset('cora')",
"Citeseer": "CitationGraphDataset('citeseer')",
"PubMed": "CitationGraphDataset('pubmed')",
"QM7b": "QM7b()",
......
......@@ -22,8 +22,8 @@ def load_dataset(name):
data = RedditDataset(self_loop=True)
g = data.graph
else:
from dgl.data import CoraDataset
data = CoraDataset()
from dgl.data import CitationGraphDataset
data = CitationGraphDataset('cora')
g = dgl.DGLGraph(data.graph)
train_mask = data.train_mask
val_mask = data.val_mask
......
......@@ -2,7 +2,7 @@
from __future__ import absolute_import
from . import citation_graph as citegrh
from .citation_graph import CoraBinary, CitationGraphDataset, CoraDataset
from .citation_graph import CoraBinary, CitationGraphDataset
from .minigc import *
from .tree import *
from .utils import *
......
......@@ -16,7 +16,7 @@ from ..graph import DGLGraph
from ..graph import batch as graph_batch
_urls = {
'cora' : 'dataset/cora_raw.zip',
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
'cora_binary' : 'dataset/cora_binary.zip',
......@@ -29,16 +29,22 @@ def _pickle_load(pkl_file):
return pkl.load(pkl_file)
class CitationGraphDataset(object):
r"""The citation graph dataset, including citeseer and pubmed.
r"""The citation graph dataset, including cora, citeseer and pubmed.
Nodes mean authors and edges mean citation relationships.
Parameters
-----------
name: str
name can be 'citeseer' or 'pubmed'.
name can be 'cora', 'citeseer' or 'pubmed'.
"""
def __init__(self, name):
assert name.lower() in ['citeseer', 'pubmed']
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we used the pre-processing in pygcn (https://github.com/tkipf/pygcn)
# for Cora, which is slightly different from the one used in the GCN paper.
if name.lower() == 'cora':
name = 'cora_v2'
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
......@@ -157,7 +163,7 @@ def _sample_mask(idx, l):
return mask
def load_cora():
    """Load the Cora citation graph dataset.

    Returns
    -------
    CitationGraphDataset
        The dataset object built by ``CitationGraphDataset('cora')``.
    """
    # Build the dataset once. Constructing ``CoraDataset()`` first and then
    # rebinding the name would only trigger a download/parse whose result is
    # immediately thrown away.
    return CitationGraphDataset('cora')
def load_citeseer():
......@@ -353,67 +359,6 @@ class CoraBinary(object):
batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels
class CoraDataset(object):
    r"""Cora citation network dataset.

    Nodes mean authors and edges mean citation relationships.

    Constructing the dataset downloads the archive registered under
    ``_urls['cora']``, extracts it into ``<download_dir>/cora`` and parses
    ``cora.content`` and ``cora.cites``. The dataset contains exactly one
    graph, retrieved with ``dataset[0]``.

    Attributes
    ----------
    name : str
        Always ``'cora'``.
    dir : str
        Download directory returned by ``get_download_dir()``.
    zip_file_path : str
        Location the archive is downloaded to.
    graph : networkx.DiGraph
        Graph built from the symmetrized adjacency matrix.
    features : numpy.ndarray
        Dense node feature matrix, after ``_normalize``.
    labels : numpy.ndarray
        Per-node column index of the non-zero entry in the label matrix
        returned by ``_encode_onehot``.
    num_labels : int
        Number of columns of that label matrix.
    train_mask, val_mask, test_mask
        Masks produced by ``_sample_mask`` for node indices ``0..139``,
        ``200..499`` and ``500..1499`` respectively.
    """

    def __init__(self):
        self.name = 'cora'
        self.dir = get_download_dir()
        self.zip_file_path = '{}/{}.zip'.format(self.dir, self.name)
        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
        extract_archive(self.zip_file_path,
                        '{}/{}'.format(self.dir, self.name))
        self._load()

    def _load(self):
        """Parse the extracted Cora files and populate the dataset attributes."""
        # Same directory __init__ extracted the archive into, so the two
        # methods cannot drift apart.
        data_dir = '{}/{}'.format(self.dir, self.name)

        # Every row is read as strings; the code below uses column 0 as the
        # node id, the last column as the class label and the columns in
        # between as the feature vector.
        idx_features_labels = np.genfromtxt(
            '{}/cora.content'.format(data_dir), dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1],
                                 dtype=np.float32)
        labels = _encode_onehot(idx_features_labels[:, -1])
        self.num_labels = labels.shape[1]
        num_nodes = labels.shape[0]

        # build graph: map the raw node ids onto consecutive row indices
        idx = np.asarray(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges_unordered = np.genfromtxt('{}/cora.cites'.format(data_dir),
                                        dtype=np.int32)
        edges = np.asarray(list(map(idx_map.get, edges_unordered.flatten())),
                           dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]),
                             (edges[:, 0], edges[:, 1])),
                            shape=(num_nodes, num_nodes),
                            dtype=np.float32)

        # build symmetric adjacency matrix
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        self.graph = nx.from_scipy_sparse_matrix(adj, create_using=nx.DiGraph())

        features = _normalize(features)
        self.features = np.asarray(features.todense())
        # Collapse the label matrix to one column index per node.
        self.labels = np.where(labels)[1]

        # Fixed split defined purely by node index ranges.
        self.train_mask = _sample_mask(range(140), num_nodes)
        self.val_mask = _sample_mask(range(200, 500), num_nodes)
        self.test_mask = _sample_mask(range(500, 1500), num_nodes)

    def __getitem__(self, idx):
        """Return the single graph of the dataset as a ``DGLGraph``.

        Parameters
        ----------
        idx : int
            Must be ``0``.

        Returns
        -------
        DGLGraph
            Graph whose ``ndata`` carries ``train_mask``, ``val_mask``,
            ``test_mask``, ``label`` and ``feat``.

        Raises
        ------
        AssertionError
            If ``idx`` is not ``0``.
        """
        assert idx == 0, "This dataset has only one graph"
        g = DGLGraph(self.graph)
        g.ndata['train_mask'] = self.train_mask
        g.ndata['val_mask'] = self.val_mask
        g.ndata['test_mask'] = self.test_mask
        g.ndata['label'] = self.labels
        g.ndata['feat'] = self.features
        return g

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1
def _normalize(mx):
"""Row-normalize sparse matrix"""
rowsum = np.asarray(mx.sum(1))
......@@ -430,4 +375,3 @@ def _encode_onehot(labels):
labels_onehot = np.asarray(list(map(classes_dict.get, labels)),
dtype=np.int32)
return labels_onehot
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment