Unverified Commit 73b9c6f1 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] GDELTDataset (#1911)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* gdelt dataset

* Update gdelt.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 06ea03d0
......@@ -14,7 +14,7 @@ from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor
from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT
from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
......
from scipy import io
""" GDELT dataset for temporal graph """
import numpy as np
import os
import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt
from ..utils import retry_method_with_fix
from .. import convert
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_info, load_info, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F
class GDELT(object):
"""
The Global Database of Events, Language, and Tone (GDELT) dataset.
This contains events that happened all over the world (i.e., every protest held anywhere
in Russia on a given day is collapsed to a single entry).
class GDELTDataset(DGLBuiltinDataset):
r"""GDELT dataset for event-based temporal graph
This Dataset consists of
events collected from 1/1/2018 to 1/31/2018 (15 minutes time granularity).
The Global Database of Events, Language, and Tone (GDELT) dataset.
This contains events that happened all over the world (i.e., every protest
held anywhere in Russia on a given day is collapsed to a single entry).
This Dataset consists of events collected from 1/1/2018 to 1/31/2018
(15 minutes time granularity).
Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `The Global Database of Events, Language, and Tone (GDELT)
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Statistics
----------
Train examples: 2,304
Valid examples: 288
Test examples: 384
Parameters
------------
mode: str
Load train/valid/test data. Has to be one of ['train', 'valid', 'test']
----------
mode : str
Must be one of ('train', 'valid', 'test'). Default: 'train'
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: False
Attributes
----------
start_time : int
Start time of the temporal graph
end_time : int
End time of the temporal graph
is_temporal : bool
Does the dataset contain temporal graphs
Examples
----------
>>> # get train, valid, test dataset
>>> train_data = GDELTDataset()
>>> valid_data = GDELTDataset(mode='valid')
>>> test_data = GDELTDataset(mode='test')
>>>
>>> # length of train set
>>> train_size = len(train_data)
>>>
>>> for g in train_data:
...     e_feat = g.edata['rel_type']
...     # your code here
...
>>>
"""
_url = {
'train': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/train.txt',
'valid': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/valid.txt',
'test': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/test.txt',
}
def __init__(self, mode):
assert mode.lower() in self._url, "Mode not valid"
self.dir = get_download_dir()
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid."
self.mode = mode
# self.graphs = []
train_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64)
if self.mode == 'train':
self._load(train_data)
elif self.mode == 'valid':
val_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
self._load(np.concatenate([train_data, val_data], axis=0))
elif self.mode == 'test':
val_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
test_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'test.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
val_data[:, 3] = -1
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _download(self):
for dname in self._url:
dpath = os.path.join(
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
@retry_method_with_fix(_download)
def _load(self, data):
self.num_nodes = 23033
_url = _get_dgl_url('dataset/gdelt.zip')
super(GDELTDataset, self).__init__(name='GDELT',
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
file_path = os.path.join(self.raw_path, self.mode + '.txt')
self.data = loadtxt(file_path, delimiter='\t').astype(np.int64)
# The source code is not released, but the paper indicates there're
# totally 137 samples. The cutoff below has exactly 137 samples.
self.data = data
self.time_index = np.floor(data[:, 3]/15).astype(np.int64)
self.start_time = self.time_index[self.time_index != -1].min()
self.end_time = self.time_index.max()
self.time_index = np.floor(self.data[:, 3] / 15).astype(np.int64)
self._start_time = self.time_index.min()
self._end_time = self.time_index.max()
def has_cache(self):
    r"""Check whether the cached info file for the current mode exists.

    The cache file is ``<mode>_info.pkl`` located under ``self.save_path``.

    Returns
    -------
    bool
        True if the file exists, False otherwise.
    """
    cache_name = self.mode + '_info.pkl'
    return os.path.exists(os.path.join(self.save_path, cache_name))
def save(self):
    r"""Write the processed data and time bounds to the per-mode info file.

    The file is ``<mode>_info.pkl`` under ``self.save_path`` and holds the
    keys ``'data'``, ``'time_index'``, ``'start_time'`` and ``'end_time'``.
    """
    target = os.path.join(self.save_path, self.mode + '_info.pkl')
    payload = {
        'data': self.data,
        'time_index': self.time_index,
        'start_time': self.start_time,
        'end_time': self.end_time,
    }
    save_info(target, payload)
def load(self):
    r"""Restore the processed data and time bounds from the per-mode info file.

    Reads ``<mode>_info.pkl`` under ``self.save_path`` and populates
    ``data``, ``time_index``, ``_start_time`` and ``_end_time``.
    """
    cached = load_info(os.path.join(self.save_path, self.mode + '_info.pkl'))
    # Look up every key before assigning anything, so a missing key leaves
    # the instance untouched.
    data, time_index = cached['data'], cached['time_index']
    first, last = cached['start_time'], cached['end_time']
    self.data = data
    self.time_index = time_index
    self._start_time = first
    self._end_time = last
@property
def start_time(self):
r""" Start time of events in the temporal graph
def __getitem__(self, idx):
if idx >= len(self) or idx < 0:
Returns
-------
int
"""
return self._start_time
@property
def end_time(self):
    r"""Return the end time of events in the temporal graph.

    The value is the one stored in ``self._end_time``.

    Returns
    -------
    int
    """
    latest = self._end_time
    return latest
def __getitem__(self, t):
r""" Get graph by with events before time `t + self.start_time`
Parameters
----------
t : int
Time, its value must be in range [0, `self.end_time` - `self.start_time`]
Returns
-------
dgl.DGLGraph
graph structure and edge feature
- edata['rel_type']: edge type
"""
if t >= len(self) or t < 0:
raise IndexError("Index out of range")
i = idx + self.start_time
i = t + self.start_time
row_mask = self.time_index <= i
edges = self.data[row_mask][:, [0, 2]]
rate = self.data[row_mask][:, 1]
g = convert.graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = rate.reshape(-1, 1)
g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
return g
def __len__(self):
return self.end_time - self.start_time + 1
@property
def num_nodes(self):
return 23033
r"""Number of graphs in the dataset"""
return self._end_time - self._start_time + 1
@property
def is_temporal(self):
    r"""Tell whether the dataset contains temporal graphs.

    Returns
    -------
    bool
        Always ``True`` for this dataset.
    """
    return True
GDELT = GDELTDataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment