Unverified Commit 06ea03d0 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] ICEWS18Dataset (#1913)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* ICEWS18Dataset
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent b5be4f4b
# Post-commit import block for dgl/data/__init__.py (reconstructed from the
# fused two-column diff; this commit adds ICEWS18Dataset to the icews18 import).
from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT
from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
"""ICEWS18 dataset for temporal graph"""
# Post-commit header (reconstructed from the fused two-column diff): the old
# scipy/datetime/warnings/download imports are removed; the dataset now builds
# on DGLBuiltinDataset with graph (de)serialization helpers.
import os

import numpy as np

from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, _get_dgl_url, save_graphs, load_graphs
from ..convert import graph as dgl_graph
from .. import backend as F
class ICEWS18Dataset(DGLBuiltinDataset):
    r"""ICEWS18 dataset for temporal graph

    Integrated Crisis Early Warning System (ICEWS18)

    Event data consists of coded interactions between socio-political
    actors (i.e., cooperative or hostile actions between individuals,
    groups, sectors and nation states). This Dataset consists of events
    from 1/1/2018 to 10/31/2018 (24 hours time granularity).

    Reference:

    - `Recurrent Event Network for Reasoning over Temporal
      Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
    - `ICEWS Coded Event Data
      <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_

    Statistics
    ----------
    Train examples: 240
    Valid examples: 30
    Test examples: 34
    Nodes per graph: 23033

    Parameters
    ----------
    mode : str
        Load train/valid/test data. Has to be one of ['train', 'valid', 'test']
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.

    Attributes
    ----------
    is_temporal : bool
        Is the dataset contains temporal graphs

    Examples
    --------
    >>> # get train, valid, test set
    >>> train_data = ICEWS18Dataset()
    >>> valid_data = ICEWS18Dataset(mode='valid')
    >>> test_data = ICEWS18Dataset(mode='test')
    >>>
    >>> train_size = len(train_data)
    >>> for g in train_data:
    ....    e_feat = g.edata['rel_type']
    ....    # your code here
    ....
    >>>
    """

    def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
        mode = mode.lower()
        assert mode in ['train', 'valid', 'test'], "Mode not valid"
        self.mode = mode
        _url = _get_dgl_url('dataset/icews18.zip')
        super(ICEWS18Dataset, self).__init__(name='ICEWS18',
                                             url=_url,
                                             raw_dir=raw_dir,
                                             force_reload=force_reload,
                                             verbose=verbose)

    def process(self):
        r"""Parse the downloaded event table into a list of cumulative graphs.

        Reads ``{mode}.txt`` (tab-separated ``src, rel, dst, time`` integer
        columns) and, for each 24-hour time step, builds a graph containing
        every event up to and including that step.
        """
        data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
                       delimiter='\t').astype(np.int64)
        num_nodes = 23033  # NOTE(review): kept from upstream but unused here
        # The source code is not released, but the paper indicates there're
        # totally 137 samples. The cutoff below has exactly 137 samples.
        time_index = np.floor(data[:, 3] / 24).astype(np.int64)
        start_time = time_index[time_index != -1].min()
        end_time = time_index.max()
        self._graphs = []
        for i in range(start_time, end_time + 1):
            # Cumulative snapshot: all events whose time step is <= i.
            row_mask = time_index <= i
            edges = data[row_mask][:, [0, 2]]
            rate = data[row_mask][:, 1]
            g = dgl_graph((edges[:, 0], edges[:, 1]))
            g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1),
                                           dtype=F.data_type_dict['int64'])
            self._graphs.append(g)

    def has_cache(self):
        r"""Return True if a cached binary graph file exists for this mode."""
        graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
        return os.path.exists(graph_path)

    def save(self):
        r"""Serialize the processed graphs to a per-mode binary file."""
        graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
        save_graphs(graph_path, self._graphs)

    def load(self):
        r"""Load the cached graphs, skipping :meth:`process`."""
        graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
        self._graphs = load_graphs(graph_path)[0]

    def __getitem__(self, idx):
        r"""Get graph by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
            graph structure and edge feature

            - edata['rel_type']: edge type
        """
        return self._graphs[idx]

    def __len__(self):
        r"""Number of graphs in the dataset"""
        return len(self._graphs)

    @property
    def is_temporal(self):
        r"""Is the dataset contains temporal graphs

        Returns
        -------
        bool
        """
        return True


# Backward-compatible alias for the pre-refactor class name.
ICEWS18 = ICEWS18Dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment