Unverified Commit 9c790b11 authored by VoVAllen's avatar VoVAllen Committed by GitHub

[Dataset] Add CoraFull, Amazon, KarateClub, Coauthor Dataset (#855)

* convert np.ndarray to backend tensor

* add datasets

* add qm7

* add dataset

* add dataset

* fix

* change ppi

* tu dataset

* add datasets

* fix

* fix

* fix

* fix

* add docstring

* docs

* doc
parent 33b8700b
......@@ -34,6 +34,78 @@ For more information about the dataset, see `Sentiment Analysis <https://nlp.sta
.. autoclass:: SST
:members: __getitem__, __len__
Karate Club dataset
```````````````````````````````````
.. autoclass:: KarateClub
:members: __getitem__, __len__
Citation Network dataset
```````````````````````````````````
.. autoclass:: CitationGraphDataset
:members: __getitem__, __len__
Cora Citation Network dataset
```````````````````````````````````
.. autoclass:: CoraDataset
:members: __getitem__, __len__
CoraFull dataset
```````````````````````````````````
.. autoclass:: CoraFull
:members: __getitem__, __len__
Amazon Co-Purchase dataset
```````````````````````````````````
.. autoclass:: AmazonCoBuy
:members: __getitem__, __len__
Coauthor dataset
```````````````````````````````````
.. autoclass:: Coauthor
:members: __getitem__, __len__
BitcoinOTC dataset
```````````````````````````````````
.. autoclass:: BitcoinOTC
:members: __getitem__, __len__
ICEWS18 dataset
```````````````````````````````````
.. autoclass:: ICEWS18
:members: __getitem__, __len__
QM7b dataset
```````````````````````````````````
.. autoclass:: QM7b
:members: __getitem__, __len__
GDELT dataset
```````````````````````````````````
.. autoclass:: GDELT
:members: __getitem__, __len__
Mini graph classification dataset
`````````````````````````````````
......
Dataset (Temporary)
.. table::
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
| Dataset Name | Usage |# of graphs|Avg. # of nodes|Avg. # of edges| Node field |Edge field |Temporal|
+================+==========================================================+===========+===============+===============+============================================+===========+========+
|BitcoinOTC |BitcoinOTC() | 136| 6005.00| 21209.98| |h |True |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Cora |CoraDataset() | 1| 2708.00| 10556.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Citeseer |CitationGraphDataset('citeseer') | 1| 3327.00| 9228.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|PubMed |CitationGraphDataset('pubmed') | 1| 19717.00| 88651.00|train_mask, val_mask, test_mask, label, feat| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|QM7b |QM7b() | 7211| 15.42| 244.95| |h |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Reddit |RedditDataset() | 1| 232965.00| 114615892.00|train_mask, val_mask, test_mask, feat, label| |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|ENZYMES |TUDataset('ENZYMES') | 600| 32.63| 124.27|node_labels, node_attr | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|DD |TUDataset('DD') | 1178| 284.32| 1431.32|node_labels | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|COLLAB |TUDataset('COLLAB') | 5000| 74.49| 9830.00| | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|MUTAG |TUDataset('MUTAG') | 188| 17.93| 39.59|node_labels |edge_labels|False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|PROTEINS |TUDataset('PROTEINS') | 1113| 39.06| 145.63|node_labels, node_attr | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|PPI |PPIDataset('train')/PPIDataset('valid')/PPIDataset('test')| 20| 2245.30| 63563.70|feat | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|KarateClub |KarateClub() | 1| 34.00| 156.00|label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Amazon computer |AmazonCoBuy('computers') | 1| 13752.00| 574418.00|feat, label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Amazon photo |AmazonCoBuy('photo') | 1| 7650.00| 287326.00|feat, label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Coauthor cs |Coauthor('cs') | 1| 18333.00| 327576.00|feat, label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|Coauthor physics|Coauthor('physics') | 1| 34493.00| 991848.00|feat, label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|GDELT |GDELT('train')/GDELT('valid')/GDELT('test') | 2304| 23033.00| 811333.15| |rel_type |True |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|ICEWS18 |ICEWS18('train')/ICEWS18('valid')/ICEWS18('test') | 240| 23033.00| 192640.22| |rel_type |True |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
|CoraFull |CoraFull() | 1| 19793.00| 130622.00|feat, label | |False |
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+
\ No newline at end of file
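Each entry in the ``Usage`` column above is a constructor expression. As a minimal sketch (assuming DGL is installed and the downloads succeed on first use):

.. code-block:: python

    from dgl.data import CoraDataset, TUDataset

    g = CoraDataset()[0]          # single graph; masks, 'label' and 'feat' live in g.ndata
    mutag = TUDataset('MUTAG')    # 188 (DGLGraph, label) pairs
    g0, y0 = mutag[0]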
from pytablewriter import RstGridTableWriter, MarkdownTableWriter
import numpy as np
import pandas as pd
from dgl import DGLGraph
from dgl.data.gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor
from dgl.data.karate import KarateClub
from dgl.data.gindt import GINDataset
from dgl.data.bitcoinotc import BitcoinOTC
from dgl.data.gdelt import GDELT
from dgl.data.icews18 import ICEWS18
from dgl.data.qm7b import QM7b
# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, CoraDataset, PPIDataset, RedditDataset, TUDataset
ds_list = {
"BitcoinOTC": "BitcoinOTC()",
"Cora": "CoraDataset()",
"Citeseer": "CitationGraphDataset('citeseer')",
"PubMed": "CitationGraphDataset('pubmed')",
"QM7b": "QM7b()",
"Reddit": "RedditDataset()",
"ENZYMES": "TUDataset('ENZYMES')",
"DD": "TUDataset('DD')",
"COLLAB": "TUDataset('COLLAB')",
"MUTAG": "TUDataset('MUTAG')",
"PROTEINS": "TUDataset('PROTEINS')",
"PPI": "PPIDataset('train')/PPIDataset('valid')/PPIDataset('test')",
# "Cora Binary": "CitationGraphDataset('cora_binary')",
"KarateClub": "KarateClub()",
"Amazon computer": "AmazonCoBuy('computers')",
"Amazon photo": "AmazonCoBuy('photo')",
"Coauthor cs": "Coauthor('cs')",
"Coauthor physics": "Coauthor('physics')",
"GDELT": "GDELT('train')/GDELT('valid')/GDELT('test')",
"ICEWS18": "ICEWS18('train')/ICEWS18('valid')/ICEWS18('test')",
"CoraFull": "CoraFull()",
}
writer = RstGridTableWriter()
# writer = MarkdownTableWriter()
extract_graph = lambda g: g if isinstance(g, DGLGraph) else g[0]
stat_list = []
for k, v in ds_list.items():
print(k, ' ', v)
ds = eval(v.split("/")[0])
num_nodes = []
num_edges = []
for i in range(len(ds)):
g = extract_graph(ds[i])
num_nodes.append(g.number_of_nodes())
num_edges.append(g.number_of_edges())
gg = extract_graph(ds[0])
dd = {
"Datset Name": k,
"Usage": v,
"# of graphs": len(ds),
"Avg. # of nodes": np.mean(num_nodes),
"Avg. # of edges": np.mean(num_edges),
"Node field": ', '.join(list(gg.ndata.keys())),
"Edge field": ', '.join(list(gg.edata.keys())),
# "Graph field": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], "gdata") else "",
"Temporal": hasattr(ds, "is_temporal")
}
stat_list.append(dd)
print(dd.keys())
df = pd.DataFrame(stat_list)
df = df.reindex(columns=dd.keys())
writer.from_dataframe(df)
writer.write_table()
......@@ -123,7 +123,7 @@ def graph_classify_task(prog_args):
perform graph classification task
'''
dataset = tu.TUDataset(name=prog_args.dataset)
dataset = tu.LegacyTUDataset(name=prog_args.dataset)
train_size = int(prog_args.train_ratio * len(dataset))
test_size = int(prog_args.test_ratio * len(dataset))
val_size = int(len(dataset) - train_size - test_size)
......
......@@ -17,7 +17,7 @@ import torch.nn.functional as F
import argparse
from sklearn.metrics import f1_score
from gat import GAT
from dgl.data.ppi import PPIDataset
from dgl.data.ppi import LegacyPPIDataset
from torch.utils.data import DataLoader
def collate(sample):
......@@ -54,9 +54,9 @@ def main(args):
# define loss function
loss_fcn = torch.nn.BCEWithLogitsLoss()
# create the dataset
train_dataset = PPIDataset(mode='train')
valid_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')
train_dataset = LegacyPPIDataset(mode='train')
valid_dataset = LegacyPPIDataset(mode='valid')
test_dataset = LegacyPPIDataset(mode='test')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
......
......@@ -2,15 +2,21 @@
from __future__ import absolute_import
from . import citation_graph as citegrh
from .citation_graph import CoraBinary
from .citation_graph import CoraBinary, CitationGraphDataset, CoraDataset
from .minigc import *
from .tree import *
from .utils import *
from .sbm import SBMMixture
from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset
from .ppi import PPIDataset, LegacyPPIDataset
from .tu import TUDataset, LegacyTUDataset
from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor
from .karate import KarateClub
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC
from .gdelt import GDELT
from .icews18 import ICEWS18
from .qm7b import QM7b
def register_data_args(parser):
......
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
from .utils import get_download_dir, download, extract_archive
class BitcoinOTC(object):
"""
This is a who-trusts-whom network of people who trade using Bitcoin
on a platform called Bitcoin OTC.
Since Bitcoin users are anonymous, there is a need to maintain a
record of users' reputation to prevent transactions with fraudulent
and risky users. Members of Bitcoin OTC rate other members on a
scale of -10 (total distrust) to +10 (total trust), in steps of 1.
Reference:
- `Bitcoin OTC trust weighted signed network <http://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_
- `EvolveGCN: Evolving Graph
Convolutional Networks for Dynamic Graphs
<https://arxiv.org/abs/1902.10191>`_
"""
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
def __init__(self):
self.dir = get_download_dir()
self.zip_path = os.path.join(
self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv.gz")
download(self._url, path=self.zip_path)
extract_archive(self.zip_path, os.path.join(
self.dir, 'bitcoin'))
self.path = os.path.join(
self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv")
self.graphs = []
self._load(self.path)
def _load(self, filename):
data = np.loadtxt(filename, delimiter=',').astype(np.int64)
data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()
num_nodes = data[:, 0:2].max() - data[:, 0:2].min() + 1
delta = datetime.timedelta(days=14).total_seconds()
# The source code is not released, but the paper indicates there are
# 137 samples in total. The cutoff below yields exactly 137 samples.
time_index = np.around(
(data[:, 3] - data[:, 3].min())/delta).astype(np.int64)
for i in range(time_index.max()):
g = DGLGraph()
g.add_nodes(num_nodes)
row_mask = time_index <= i
edges = data[row_mask][:, 0:2]
rate = data[row_mask][:, 2]
g.add_edges(edges[:, 0], edges[:, 1])
g.edata['h'] = rate.reshape(-1, 1)
self.graphs.append(g)
def __getitem__(self, idx):
return self.graphs[idx]
def __len__(self):
return len(self.graphs)
@property
def is_temporal(self):
return True
\ No newline at end of file
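A minimal usage sketch for the class above (assuming the SNAP download succeeds; each graph is a cumulative snapshot, two weeks apart):

from dgl.data.bitcoinotc import BitcoinOTC
dataset = BitcoinOTC()          # downloads and parses soc-sign-bitcoinotc.csv
print(len(dataset))             # number of two-week snapshots
g = dataset[-1]                 # the last snapshot holds the most edges
print(g.number_of_edges(), g.edata['h'].shape)   # ratings stored as edge feature 'h'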
......@@ -10,6 +10,7 @@ import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import os, sys
from dgl import DGLGraph
import dgl
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
......@@ -28,7 +29,16 @@ def _pickle_load(pkl_file):
return pkl.load(pkl_file)
class CitationGraphDataset(object):
r"""The citation graph dataset, including citeseer and pubmeb.
Nodes mean authors and edges mean citation relationships.
Parameters
-----------
name: str
name can be 'citeseer' or 'pubmed'.
"""
def __init__(self, name):
assert name.lower() in ['citeseer', 'pubmed']
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
......@@ -112,7 +122,14 @@ class CitationGraphDataset(object):
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
return self
assert idx == 0, "This dataset has only one graph"
g = DGLGraph(self.graph)
g.ndata['train_mask'] = self.train_mask
g.ndata['val_mask'] = self.val_mask
g.ndata['test_mask'] = self.test_mask
g.ndata['label'] = self.labels
g.ndata['feat'] = self.features
return g
def __len__(self):
return 1
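For reference, a minimal sketch of how the new __getitem__ is meant to be used (assuming the citation data downloads succeed):

from dgl.data import CitationGraphDataset
data = CitationGraphDataset('pubmed')
g = data[0]                        # the single graph, with all fields attached as ndata
train_mask = g.ndata['train_mask']
feats, labels = g.ndata['feat'], g.ndata['label']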
......@@ -337,6 +354,9 @@ class CoraBinary(object):
return batched_graphs, batched_pmpds, batched_labels
class CoraDataset(object):
r"""Cora citation network dataset. Nodes mean author and edges mean citation
relationships.
"""
def __init__(self):
self.name = 'cora'
self.dir = get_download_dir()
......@@ -378,6 +398,21 @@ class CoraDataset(object):
self.train_mask = _sample_mask(range(140), labels.shape[0])
self.val_mask = _sample_mask(range(200, 500), labels.shape[0])
self.test_mask = _sample_mask(range(500, 1500), labels.shape[0])
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
g = DGLGraph(self.graph)
g.ndata['train_mask'] = self.train_mask
g.ndata['val_mask'] = self.val_mask
g.ndata['test_mask'] = self.test_mask
g.ndata['label'] = self.labels
g.ndata['feat'] = self.features
return g
def __len__(self):
return 1
def _normalize(mx):
"""Row-normalize sparse matrix"""
......
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt
class GDELT(object):
"""
The Global Database of Events, Language, and Tone (GDELT) dataset.
It contains events that happened all over the world (e.g., every protest held
anywhere in Russia on a given day is collapsed to a single entry).
This dataset consists of
events collected from 1/1/2018 to 1/31/2018 (15-minute time granularity).
Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Parameters
------------
mode: str
Load train/valid/test data. Has to be one of ['train', 'valid', 'test']
"""
_url = {
'train': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/train.txt',
'valid': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/valid.txt',
'test': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/test.txt',
}
def __init__(self, mode):
assert mode.lower() in self._url, "Mode not valid"
self.dir = get_download_dir()
self.mode = mode
# self.graphs = []
for dname in self._url:
dpath = os.path.join(
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
train_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64)
if self.mode == 'train':
self._load(train_data)
elif self.mode == 'valid':
val_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
self._load(np.concatenate([train_data, val_data], axis=0))
elif self.mode == 'test':
val_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
test_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'test.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
val_data[:, 3] = -1
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _load(self, data):
# Raw timestamps (in minutes) are bucketed into 15-minute snapshots.
# Rows carried over from earlier splits have their timestamp set to -1:
# they are excluded from start_time but still included in every snapshot,
# since the row mask below keeps all rows with time index <= i.
self.data = data
self.time_index = np.floor(data[:, 3]/15).astype(np.int64)
self.start_time = self.time_index[self.time_index != -1].min()
self.end_time = self.time_index.max()
def __getitem__(self, idx):
if idx >= len(self) or idx < 0:
raise IndexError("Index out of range")
i = idx + self.start_time
g = DGLGraph()
g.add_nodes(self.num_nodes)
row_mask = self.time_index <= i
edges = self.data[row_mask][:, [0, 2]]
rel_type = self.data[row_mask][:, 1]
g.add_edges(edges[:, 0], edges[:, 1])
g.edata['rel_type'] = rel_type.reshape(-1, 1)
return g
def __len__(self):
return self.end_time - self.start_time + 1
@property
def num_nodes(self):
return 23033
@property
def is_temporal(self):
return True
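GDELT builds snapshots lazily: graph i returned by __getitem__ contains every event whose 15-minute time index is at most start_time + i. A minimal sketch:

from dgl.data.gdelt import GDELT
train = GDELT('train')
print(len(train))                  # number of 15-minute snapshots
g = train[0]                       # earliest snapshot
print(g.number_of_nodes())         # fixed at 23033 entities
print(g.edata['rel_type'].shape)   # one relation type per edge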
import scipy.sparse as sp
import numpy as np
from dgl import graph_index, DGLGraph, transform
import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
__all__ = ["AmazonCoBuy", "Coauthor", "CoraFull"]
def eliminate_self_loops(A):
"""Remove self-loops from the adjacency matrix."""
A = A.tolil()
A.setdiag(0)
A = A.tocsr()
A.eliminate_zeros()
return A
class GNNBenchmarkDataset(object):
"""Base Class for GNN Benchmark dataset from https://github.com/shchur/gnn-benchmark#datasets"""
_url = {}
def __init__(self, name):
assert name.lower() in self._url, "Name not valid"
self.dir = get_download_dir()
self.path = os.path.join(
self.dir, 'gnn_benckmark', self._url[name.lower()].split('/')[-1])
download(self._url[name.lower()], path=self.path)
g = self.load_npz(self.path)
self.data = [g]
@staticmethod
def load_npz(file_name):
with np.load(file_name) as loader:
loader = dict(loader)
num_nodes = loader['adj_shape'][0]
adj_matrix = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], loader['adj_indptr']),
shape=loader['adj_shape']).tocoo()
if 'attr_data' in loader:
# Attributes are stored as a sparse CSR matrix
attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], loader['attr_indptr']),
shape=loader['attr_shape']).todense()
elif 'attr_matrix' in loader:
# Attributes are stored as a (dense) np.ndarray
attr_matrix = loader['attr_matrix']
else:
attr_matrix = None
if 'labels_data' in loader:
# Labels are stored as a CSR matrix
labels = sp.csr_matrix((loader['labels_data'], loader['labels_indices'], loader['labels_indptr']),
shape=loader['labels_shape']).todense()
elif 'labels' in loader:
# Labels are stored as a numpy array
labels = loader['labels']
else:
labels = None
g = DGLGraph()
g.add_nodes(num_nodes)
g.add_edges(adj_matrix.row, adj_matrix.col)
g.add_edges(adj_matrix.col, adj_matrix.row)
g.ndata['feat'] = attr_matrix
g.ndata['label'] = labels
return g
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self.data[0]
def __len__(self):
return len(self.data)
class CoraFull(GNNBenchmarkDataset):
r"""
Extended Cora dataset from `Deep Gaussian Embedding of Graphs:
Unsupervised Inductive Learning via Ranking`. Nodes represent papers and edges represent citations.
Reference: https://github.com/shchur/gnn-benchmark#datasets
"""
_url = {"cora_full":'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/cora_full.npz'}
def __init__(self):
super().__init__("cora_full")
class Coauthor(GNNBenchmarkDataset):
r"""
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge. Here, nodes are authors that are connected by an edge if they
co-authored a paper; node features represent paper keywords for each author's papers, and class
labels indicate the most active field of study for each author.
Parameters
---------------
name: str
Name of the dataset, has to be 'cs' or 'physics'
"""
_url = {
'cs': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/ms_academic_cs.npz",
'physics': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/ms_academic_phy.npz"
}
class AmazonCoBuy(GNNBenchmarkDataset):
r"""
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley
et al., 2015], where nodes represent goods, edges indicate that two goods are frequently bought
together, node features are bag-of-words encoded product reviews, and class labels are given by the
product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Parameters
---------------
name: str
Name of the dataset; has to be 'computers' or 'photo'
"""
_url = {
'computers': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_computers.npz",
'photo': "https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_photo.npz"
}
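The three subclasses above share GNNBenchmarkDataset's loader, so usage is uniform. A minimal sketch (assuming the .npz downloads from the gnn-benchmark repository succeed):

from dgl.data import AmazonCoBuy, Coauthor, CoraFull
data = AmazonCoBuy('computers')    # or Coauthor('cs'), CoraFull(), ...
g = data[0]                        # each dataset holds a single graph
print(g.number_of_nodes(), g.ndata['feat'].shape, g.ndata['label'].shape)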
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
import datetime
import warnings
from .utils import get_download_dir, download, extract_archive, loadtxt
class ICEWS18(object):
"""
Integrated Crisis Early Warning System (ICEWS18)
Event data consists of coded interactions between socio-political
actors (i.e., cooperative or hostile actions between individuals,
groups, sectors and nation states).
This dataset consists of events from 1/1/2018
to 10/31/2018 (24-hour time granularity).
Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `ICEWS Coded Event Data <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Parameters
------------
mode: str
Load train/valid/test data. Has to be one of ['train', 'valid', 'test']
"""
_url = {
'train': 'https://github.com/INK-USC/RENet/raw/master/data/ICEWS18/train.txt',
'valid': 'https://github.com/INK-USC/RENet/raw/master/data/ICEWS18/valid.txt',
'test': 'https://github.com/INK-USC/RENet/raw/master/data/ICEWS18/test.txt',
}
def __init__(self, mode):
assert mode.lower() in self._url, "Mode not valid"
self.dir = get_download_dir()
self.mode = mode
self.graphs = []
for dname in self._url:
dpath = os.path.join(
self.dir, 'ICEWS18', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
train_data = loadtxt(os.path.join(
self.dir, 'ICEWS18', 'train.txt'), delimiter='\t').astype(np.int64)
if self.mode == 'train':
self._load(train_data)
elif self.mode == 'valid':
val_data = loadtxt(os.path.join(
self.dir, 'ICEWS18', 'valid.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
self._load(np.concatenate([train_data, val_data], axis=0))
elif self.mode == 'test':
val_data = loadtxt(os.path.join(
self.dir, 'ICEWS18', 'valid.txt'), delimiter='\t').astype(np.int64)
test_data = loadtxt(os.path.join(
self.dir, 'ICEWS18', 'test.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
val_data[:, 3] = -1
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _load(self, data):
num_nodes = 23033
# Raw timestamps (in hours) are bucketed into daily (24-hour) snapshots.
# Rows carried over from earlier splits have their timestamp set to -1:
# they are excluded from start_time but still included in every snapshot,
# since the row mask below keeps all rows with time index <= i.
time_index = np.floor(data[:, 3]/24).astype(np.int64)
start_time = time_index[time_index != -1].min()
end_time = time_index.max()
for i in range(start_time, end_time+1):
g = DGLGraph()
g.add_nodes(num_nodes)
row_mask = time_index <= i
edges = data[row_mask][:, [0, 2]]
rel_type = data[row_mask][:, 1]
g.add_edges(edges[:, 0], edges[:, 1])
g.edata['rel_type'] = rel_type.reshape(-1, 1)
self.graphs.append(g)
def __getitem__(self, idx):
return self.graphs[idx]
def __len__(self):
return len(self.graphs)
@property
def is_temporal(self):
return True
\ No newline at end of file
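Unlike GDELT, which materializes snapshots on demand in __getitem__, ICEWS18 precomputes every daily snapshot in _load, and indexing simply returns the cached graph. A minimal sketch:

from dgl.data.icews18 import ICEWS18
valid = ICEWS18('valid')   # train rows get timestamp -1 and appear as history in every snapshot
g = valid[0]
print(len(valid), g.number_of_nodes(), g.edata['rel_type'].shape)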
"""KarateClub Dataset
"""
import numpy as np
import networkx as nx
from dgl import DGLGraph
class KarateClub(object):
"""
Zachary's karate club is a social network of a university karate club, described in the paper
"An Information Flow Model for Conflict and Fission in Small Groups" by Wayne W. Zachary. The
network became a popular example of community structure in networks after its use by Michelle
Girvan and Mark Newman in 2002.
This dataset has only one graph; the node feature 'label' indicates which faction
a node belongs to (0 for "Mr. Hi"'s faction, 1 otherwise).
"""
def __init__(self):
kG = nx.karate_club_graph()
self.label = np.array(
[kG.node[i]['club'] != 'Mr. Hi' for i in kG.nodes]).astype(np.int64)
g = DGLGraph(kG)
g.ndata['label'] = self.label
self.data = [g]
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self.data[0]
def __len__(self):
return len(self.data)
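A minimal sketch (no download needed; the graph comes straight from networkx):

from dgl.data.karate import KarateClub
g = KarateClub()[0]
print(g.number_of_nodes())   # 34 members
print(g.ndata['label'])      # 0 = "Mr. Hi"'s faction, 1 = the "Officer" faction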
......@@ -11,6 +11,7 @@ from ..graph import DGLGraph
_url = 'dataset/ppi.zip'
class PPIDataset(object):
"""A toy Protein-Protein Interaction network dataset.
......@@ -29,6 +30,7 @@ class PPIDataset(object):
mode : str
('train', 'valid', 'test').
"""
assert mode in ['train', 'valid', 'test']
self.mode = mode
self._load()
self._preprocess()
......@@ -117,6 +119,38 @@ class PPIDataset(object):
if self.mode == 'test':
return len(self.test_mask_list)
def __getitem__(self, item):
"""Get the i^th sample.
Paramters
---------
idx : int
The sample index.
Returns
-------
(dgl.DGLGraph, ndarray)
The graph, and its label.
"""
if self.mode == 'train':
g = self.train_graphs[item]
g.ndata['feat'] = self.features[self.train_mask_list[item]]
label = self.train_labels[item]
elif self.mode == 'valid':
g = self.valid_graphs[item]
g.ndata['feat'] = self.features[self.valid_mask_list[item]]
label = self.valid_labels[item]
elif self.mode == 'test':
g = self.test_graphs[item]
g.ndata['feat'] = self.features[self.test_mask_list[item]]
label = self.test_labels[item]
return g, label
class LegacyPPIDataset(PPIDataset):
"""Legacy version of PPI Dataset
"""
def __getitem__(self, item):
"""Get the i^th sample.
......@@ -135,4 +169,4 @@ class PPIDataset(object):
if self.mode == 'valid':
return self.valid_graphs[item], self.features[self.valid_mask_list[item]], self.valid_labels[item]
if self.mode == 'test':
return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
\ No newline at end of file
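The split keeps existing examples working: PPIDataset now returns (graph, labels) pairs with features attached to g.ndata['feat'], while LegacyPPIDataset preserves the old three-tuple layout. A minimal sketch:

from dgl.data.ppi import PPIDataset, LegacyPPIDataset
g, labels = PPIDataset(mode='train')[0]                 # features in g.ndata['feat']
g2, feats, labels2 = LegacyPPIDataset(mode='train')[0]  # old (graph, features, labels) layout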
from scipy import io
import numpy as np
from dgl import DGLGraph
import os
from .utils import get_download_dir, download
class QM7b(object):
"""
This dataset consists of 7,211 molecules with 14 regression targets.
Nodes correspond to atoms and edges to bonds. The edge feature 'h' stores
the corresponding Coulomb-matrix entry.
Reference:
- `QM7b Dataset <http://quantum-machine.org/datasets/>`_
"""
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
def __init__(self):
self.dir = get_download_dir()
self.path = os.path.join(self.dir, 'qm7b', "qm7b.mat")
download(self._url, path=self.path)
self.graphs = []
self._load(self.path)
def _load(self, filename):
data = io.loadmat(filename)
labels = data['T']
feats = data['X']
num_graphs = labels.shape[0]
self.label = labels
for i in range(num_graphs):
g = DGLGraph()
edge_list = feats[i].nonzero()
num_nodes = np.max(edge_list) + 1
g.add_nodes(num_nodes)
g.add_edges(edge_list[0], edge_list[1])
g.edata['h'] = feats[i][edge_list[0], edge_list[1]].reshape(-1, 1)
self.graphs.append(g)
def __getitem__(self, idx):
return self.graphs[idx], self.label[idx]
def __len__(self):
return len(self.graphs)
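A minimal sketch (assuming the qm7b.mat download succeeds):

from dgl.data.qm7b import QM7b
data = QM7b()
g, target = data[0]
print(len(data))            # 7211 molecules
print(g.edata['h'].shape)   # one Coulomb-matrix entry per edge
print(target.shape)         # 14 regression targets per molecule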
......@@ -41,3 +41,15 @@ class RedditDataset(object):
print(' NumValidationSamples: {}'.format(len(np.nonzero(self.val_mask)[0])))
print(' NumTestSamples: {}'.format(len(np.nonzero(self.test_mask)[0])))
def __getitem__(self, idx):
assert idx == 0, "Reddit Dataset only has one graph"
g = self.graph
g.ndata['train_mask'] = self.train_mask
g.ndata['val_mask'] = self.val_mask
g.ndata['test_mask'] = self.test_mask
g.ndata['feat'] = self.features
g.ndata['label'] = self.labels
return g
def __len__(self):
return 1
......@@ -4,10 +4,10 @@ import dgl
import os
import random
from dgl.data.utils import download, extract_archive, get_download_dir
from dgl.data.utils import download, extract_archive, get_download_dir, loadtxt
class TUDataset(object):
class LegacyTUDataset(object):
"""
LegacyTUDataset loads graph kernel datasets for graph classification.
It uses the provided node features by default; if none are provided, one-hot node labels are used instead.
......@@ -154,3 +154,106 @@ class TUDataset(object):
return self.graph_lists[0].ndata['feat'].shape[1],\
self.num_labels,\
self.max_num_node
class TUDataset(object):
"""
TUDataset loads graph kernel datasets for graph classification.
Graphs may have node labels, node attributes, edge labels, and edge attributes,
varying from dataset to dataset.
:param name: dataset name, such as `ENZYMES`, `DD`, `COLLAB`, or `MUTAG`; any dataset
name listed at https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets can be used.
"""
_url = r"https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{}.zip"
def __init__(self, name):
self.name = name
self.extract_dir = self._download()
DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
DS_graph_labels = self._idx_from_zero(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
g = dgl.DGLGraph()
g.add_nodes(int(DS_edge_list.max()) + 1)
g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
node_idx_list = []
self.max_num_node = 0
for idx in range(np.max(DS_indicator) + 1):
node_idx = np.where(DS_indicator == idx)
node_idx_list.append(node_idx[0])
if len(node_idx[0]) > self.max_num_node:
self.max_num_node = len(node_idx[0])
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
self.attr_dict = {
'node_labels': ('ndata', 'node_labels'),
'node_attributes': ('ndata', 'node_attr'),
'edge_labels': ('edata', 'edge_labels'),
'edge_attributes': ('edata', 'edge_attr'),
}
for filename, field_name in self.attr_dict.items():
try:
data = loadtxt(self._file_path(filename),
delimiter=',').astype(int)
if 'label' in filename:
data = self._idx_from_zero(data)
getattr(g, field_name[0])[field_name[1]] = data
except IOError:
pass
self.graph_lists = g.subgraphs(node_idx_list)
for g in self.graph_lists:
g.copy_from_parent()
def __getitem__(self, idx):
"""Get the i^th sample.
Paramters
---------
idx : int
The sample index.
Returns
-------
(dgl.DGLGraph, int)
DGLGraph with node feature stored in `feat` field and node label in `node_label` if available.
And its label.
"""
g = self.graph_lists[idx]
return g, self.graph_labels[idx]
def __len__(self):
return len(self.graph_lists)
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir,
"tu_{}.zip".format(
self.name))
download(self._url.format(self.name), path=zip_file_path)
extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
extract_archive(zip_file_path, extract_dir)
return extract_dir
def _file_path(self, category):
return os.path.join(self.extract_dir, self.name,
"{}_{}.txt".format(self.name, category))
@staticmethod
def _idx_from_zero(idx_tensor):
return idx_tensor - np.min(idx_tensor)
def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1], \
self.num_labels, \
self.max_num_node
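A minimal sketch of the new loader (assuming the TU Dortmund download succeeds; which ndata/edata fields exist varies by dataset, as the statistics table above shows):

from dgl.data import TUDataset
data = TUDataset('MUTAG')
g, label = data[0]
print(len(data), data.num_labels, data.max_num_node)
print(g.ndata.keys(), g.edata.keys())   # e.g. node_labels / edge_labels for MUTAG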
......@@ -8,6 +8,7 @@ import warnings
import zipfile
import tarfile
import numpy as np
from .graph_serialize import save_graphs, load_graphs, load_labels
......@@ -18,10 +19,19 @@ except ImportError:
pass
requests = requests_failed_to_import
__all__ = ['download', 'check_sha1', 'extract_archive',
__all__ = ['loadtxt', 'download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset',
'save_graphs', "load_graphs", "load_labels"]
def loadtxt(path, delimiter, dtype=None):
try:
import pandas as pd
df = pd.read_csv(path, delimiter=delimiter, header=None)
return df.values
except ImportError:
warnings.warn("Pandas is not installed, now using numpy.loadtxt to load data, "
"which could be extremely slow. Accelerate by installing pandas")
return np.loadtxt(path, delimiter=delimiter)
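For instance, with a hypothetical tab-separated file of integer triples, the helper is a drop-in for the np.loadtxt pattern used by the dataset loaders above:

data = loadtxt('edges.tsv', delimiter='\t').astype('int64')  # 2-D array; parsed via pandas when available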
def _get_dgl_url(file_url):
"""Get DGL online url for download."""
......