"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ef3844d3a83583f36d0166be6753d062b3cbd7dc"
Unverified Commit 90d2118d authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Dataset] Change Cora Split (#1583)

* Use new version of Cora

* Fix import
parent b8dffcd5
...@@ -49,13 +49,6 @@ Citation Network dataset ...@@ -49,13 +49,6 @@ Citation Network dataset
:members: __getitem__, __len__ :members: __getitem__, __len__
Cora Citation Network dataset
```````````````````````````````````
.. autoclass:: CoraDataset
:members: __getitem__, __len__
CoraFull dataset CoraFull dataset
``````````````````````````````````` ```````````````````````````````````
......
...@@ -8,7 +8,7 @@ Dataset (Temporary) ...@@ -8,7 +8,7 @@ Dataset (Temporary)
+================+==========================================================+===========+===============+===============+============================================+===========+========+ +================+==========================================================+===========+===============+===============+============================================+===========+========+
|BitcoinOTC |BitcoinOTC() | 136| 6005.00| 21209.98| |h |True | |BitcoinOTC |BitcoinOTC() | 136| 6005.00| 21209.98| |h |True |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+ +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Cora |CoraDataset() | 1| 2708.00| 10556.00|train_mask, val_mask, test_mask, label, feat| |False | |Cora |CitationGraphDataset('cora') | 1| 2708.00| 10556.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+ +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Citeseer |CitationGraphDataset('citeseer') | 1| 3327.00| 9228.00|train_mask, val_mask, test_mask, label, feat| |False | |Citeseer |CitationGraphDataset('citeseer') | 1| 3327.00| 9228.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+ +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
......
...@@ -10,11 +10,11 @@ from dgl.data.gdelt import GDELT ...@@ -10,11 +10,11 @@ from dgl.data.gdelt import GDELT
from dgl.data.icews18 import ICEWS18 from dgl.data.icews18 import ICEWS18
from dgl.data.qm7b import QM7b from dgl.data.qm7b import QM7b
# from dgl.data.qm9 import QM9 # from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, CoraDataset, PPIDataset, RedditDataset, TUDataset from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
ds_list = { ds_list = {
"BitcoinOTC": "BitcoinOTC()", "BitcoinOTC": "BitcoinOTC()",
"Cora": "CoraDataset()", "Cora": "CitationGraphDataset('cora')",
"Citeseer": "CitationGraphDataset('citeseer')", "Citeseer": "CitationGraphDataset('citeseer')",
"PubMed": "CitationGraphDataset('pubmed')", "PubMed": "CitationGraphDataset('pubmed')",
"QM7b": "QM7b()", "QM7b": "QM7b()",
......
...@@ -22,8 +22,8 @@ def load_dataset(name): ...@@ -22,8 +22,8 @@ def load_dataset(name):
data = RedditDataset(self_loop=True) data = RedditDataset(self_loop=True)
g = data.graph g = data.graph
else: else:
from dgl.data import CoraDataset from dgl.data import CitationGraphDataset
data = CoraDataset() data = CitationGraphDataset('cora')
g = dgl.DGLGraph(data.graph) g = dgl.DGLGraph(data.graph)
train_mask = data.train_mask train_mask = data.train_mask
val_mask = data.val_mask val_mask = data.val_mask
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from . import citation_graph as citegrh from . import citation_graph as citegrh
from .citation_graph import CoraBinary, CitationGraphDataset, CoraDataset from .citation_graph import CoraBinary, CitationGraphDataset
from .minigc import * from .minigc import *
from .tree import * from .tree import *
from .utils import * from .utils import *
......
...@@ -16,7 +16,7 @@ from ..graph import DGLGraph ...@@ -16,7 +16,7 @@ from ..graph import DGLGraph
from ..graph import batch as graph_batch from ..graph import batch as graph_batch
_urls = { _urls = {
'cora' : 'dataset/cora_raw.zip', 'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip', 'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip', 'pubmed' : 'dataset/pubmed.zip',
'cora_binary' : 'dataset/cora_binary.zip', 'cora_binary' : 'dataset/cora_binary.zip',
...@@ -29,16 +29,22 @@ def _pickle_load(pkl_file): ...@@ -29,16 +29,22 @@ def _pickle_load(pkl_file):
return pkl.load(pkl_file) return pkl.load(pkl_file)
class CitationGraphDataset(object): class CitationGraphDataset(object):
r"""The citation graph dataset, including citeseer and pubmeb. r"""The citation graph dataset, including cora, citeseer and pubmeb.
Nodes mean authors and edges mean citation relationships. Nodes mean authors and edges mean citation relationships.
Parameters Parameters
----------- -----------
name: str name: str
name can be 'citeseer' or 'pubmed'. name can be 'cora', 'citeseer' or 'pubmed'.
""" """
def __init__(self, name): def __init__(self, name):
assert name.lower() in ['citeseer', 'pubmed'] assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
# for Cora, which is slightly different from the one used in the GCN paper
if name.lower() == 'cora':
name = 'cora_v2'
self.name = name self.name = name
self.dir = get_download_dir() self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name) self.zip_file_path='{}/{}.zip'.format(self.dir, name)
...@@ -157,7 +163,7 @@ def _sample_mask(idx, l): ...@@ -157,7 +163,7 @@ def _sample_mask(idx, l):
return mask return mask
def load_cora(): def load_cora():
data = CoraDataset() data = CitationGraphDataset('cora')
return data return data
def load_citeseer(): def load_citeseer():
...@@ -353,67 +359,6 @@ class CoraBinary(object): ...@@ -353,67 +359,6 @@ class CoraBinary(object):
batched_labels = np.concatenate(labels, axis=0) batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels return batched_graphs, batched_pmpds, batched_labels
class CoraDataset(object):
    r"""Cora citation network dataset. Nodes mean author and edges mean citation
    relationships.

    Constructing an instance downloads the archive registered under
    ``_urls['cora']``, extracts it below the download directory and parses
    the ``cora.content`` / ``cora.cites`` files into in-memory arrays.

    The dataset holds exactly one graph, so ``len(dataset) == 1`` and only
    ``dataset[0]`` is valid.
    """

    def __init__(self):
        self.name = 'cora'
        self.dir = get_download_dir()
        self.zip_file_path = '{}/{}.zip'.format(self.dir, self.name)
        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
        # The archive is unpacked into <download dir>/cora, which is the
        # directory _load() reads from.
        extract_archive(self.zip_file_path,
                        '{}/{}'.format(self.dir, self.name))
        self._load()

    def _load(self):
        """Parse the extracted files and populate the dataset attributes.

        Sets ``num_labels``, ``graph``, ``features``, ``labels``,
        ``train_mask``, ``val_mask`` and ``test_mask`` on the instance.
        """
        # Every row of cora.content is read as strings: column 0 is the raw
        # node id, the last column is the class name, and everything in
        # between is the feature vector.
        idx_features_labels = np.genfromtxt("{}/cora/cora.content".
                                            format(self.dir),
                                            dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1],
                                 dtype=np.float32)
        labels = _encode_onehot(idx_features_labels[:, -1])
        self.num_labels = labels.shape[1]

        # build graph
        # Raw ids are remapped to consecutive row positions so they can be
        # used directly as matrix indices.
        idx = np.asarray(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges_unordered = np.genfromtxt("{}/cora/cora.cites".format(self.dir),
                                        dtype=np.int32)
        edges = np.asarray(list(map(idx_map.get, edges_unordered.flatten())),
                           dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]),
                             (edges[:, 0], edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)

        # build symmetric adjacency matrix
        # Wherever the transposed entry is larger than the original one, the
        # original entry is replaced by the transposed value.
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        self.graph = nx.from_scipy_sparse_matrix(adj, create_using=nx.DiGraph())

        features = _normalize(features)
        self.features = np.asarray(features.todense())
        # Collapse the one-hot rows to the column index of each non-zero entry.
        self.labels = np.where(labels)[1]
        # Fixed split by row position: rows 0-139 train, 200-499 validation,
        # 500-1499 test.
        self.train_mask = _sample_mask(range(140), labels.shape[0])
        self.val_mask = _sample_mask(range(200, 500), labels.shape[0])
        self.test_mask = _sample_mask(range(500, 1500), labels.shape[0])

    def __getitem__(self, idx):
        """Return the single graph as a ``DGLGraph`` with node data attached.

        The returned graph carries ``train_mask``, ``val_mask``,
        ``test_mask``, ``label`` and ``feat`` in ``ndata``. A new
        ``DGLGraph`` is built on every call.

        Raises
        ------
        AssertionError
            If ``idx`` is not 0.
        """
        assert idx == 0, "This dataset has only one graph"
        g = DGLGraph(self.graph)
        g.ndata['train_mask'] = self.train_mask
        g.ndata['val_mask'] = self.val_mask
        g.ndata['test_mask'] = self.test_mask
        g.ndata['label'] = self.labels
        g.ndata['feat'] = self.features
        return g

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1
def _normalize(mx): def _normalize(mx):
"""Row-normalize sparse matrix""" """Row-normalize sparse matrix"""
rowsum = np.asarray(mx.sum(1)) rowsum = np.asarray(mx.sum(1))
...@@ -430,4 +375,3 @@ def _encode_onehot(labels): ...@@ -430,4 +375,3 @@ def _encode_onehot(labels):
labels_onehot = np.asarray(list(map(classes_dict.get, labels)), labels_onehot = np.asarray(list(map(classes_dict.get, labels)),
dtype=np.int32) dtype=np.int32)
return labels_onehot return labels_onehot
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment