Unverified Commit 73b9c6f1 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Dataset] GDELTDataset (#1911)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* gdelt dataset

* Update gdelt.py
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 06ea03d0
...@@ -14,7 +14,7 @@ from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor ...@@ -14,7 +14,7 @@ from .gnn_benckmark import AmazonCoBuy, CoraFull, Coauthor
from .karate import KarateClub, KarateClubDataset from .karate import KarateClub, KarateClubDataset
from .gindt import GINDataset from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset from .dgl_dataset import DGLDataset, DGLBuiltinDataset
......
from scipy import io """ GDELT dataset for temporal graph """
import numpy as np import numpy as np
import os import os
import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt from .dgl_dataset import DGLBuiltinDataset
from ..utils import retry_method_with_fix from .utils import loadtxt, save_info, load_info, _get_dgl_url
from .. import convert from ..convert import graph as dgl_graph
from .. import backend as F
class GDELT(object): class GDELTDataset(DGLBuiltinDataset):
""" r"""GDELT dataset for event-based temporal graph
The Global Database of Events, Language, and Tone (GDELT) dataset.
This contains events happened all over the world (ie every protest held anywhere
in Russia on a given day is collapsed to a single entry).
This Dataset consists of The Global Database of Events, Language, and Tone (GDELT) dataset.
events collected from 1/1/2018 to 1/31/2018 (15 minutes time granularity). This contains events happened all over the world (ie every protest held
anywhere in Russia on a given day is collapsed to a single entry).
This Dataset consists of events collected from 1/1/2018 to 1/31/2018
(15 minutes time granularity).
Reference: Reference:
- `Recurrent Event Network for Reasoning over Temporal - `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_ Knowledge Graphs` <https://arxiv.org/abs/1904.05530>
- `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_ - `The Global Database of Events, Language, and Tone (GDELT) `
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>
Statistics
----------
Train examples: 2,304
Valid examples: 288
Test examples: 384
Parameters Parameters
------------ ----------
mode: str mode : str
Load train/valid/test data. Has to be one of ['train', 'valid', 'test'] Must be one of ('train', 'valid', 'test'). Default: 'train'
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
start_time : int
Start time of the temporal graph
end_time : int
End time of the temporal graph
is_temporal : bool
Does the dataset contain temporal graphs
Examples
----------
>>> # get train, valid, test dataset
>>> train_data = GDELTDataset()
>>> valid_data = GDELTDataset(mode='valid')
>>> test_data = GDELTDataset(mode='test')
>>>
>>> # length of train set
>>> train_size = len(train_data)
>>>
>>> for g in train_data:
.... e_feat = g.edata['rel_type']
.... # your code here
....
>>>
""" """
_url = { def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
'train': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/train.txt', mode = mode.lower()
'valid': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/valid.txt', assert mode in ['train', 'valid', 'test'], "Mode not valid."
'test': 'https://github.com/INK-USC/RENet/raw/master/data/GDELT/test.txt',
}
def __init__(self, mode):
assert mode.lower() in self._url, "Mode not valid"
self.dir = get_download_dir()
self.mode = mode self.mode = mode
# self.graphs = [] self.num_nodes = 23033
train_data = loadtxt(os.path.join( _url = _get_dgl_url('dataset/gdelt.zip')
self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64) super(GDELTDataset, self).__init__(name='GDELT',
if self.mode == 'train': url=_url,
self._load(train_data) raw_dir=raw_dir,
elif self.mode == 'valid': force_reload=force_reload,
val_data = loadtxt(os.path.join( verbose=verbose)
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1 def process(self):
self._load(np.concatenate([train_data, val_data], axis=0)) file_path = os.path.join(self.raw_path, self.mode + '.txt')
elif self.mode == 'test': self.data = loadtxt(file_path, delimiter='\t').astype(np.int64)
val_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'valid.txt'), delimiter='\t').astype(np.int64)
test_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'test.txt'), delimiter='\t').astype(np.int64)
train_data[:, 3] = -1
val_data[:, 3] = -1
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _download(self):
for dname in self._url:
dpath = os.path.join(
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
@retry_method_with_fix(_download)
def _load(self, data):
# The source code is not released, but the paper indicates there're # The source code is not released, but the paper indicates there're
# totally 137 samples. The cutoff below has exactly 137 samples. # totally 137 samples. The cutoff below has exactly 137 samples.
self.data = data self.time_index = np.floor(self.data[:, 3] / 15).astype(np.int64)
self.time_index = np.floor(data[:, 3]/15).astype(np.int64) self._start_time = self.time_index.min()
self.start_time = self.time_index[self.time_index != -1].min() self._end_time = self.time_index.max()
self.end_time = self.time_index.max()
def has_cache(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
return os.path.exists(info_path)
def save(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
save_info(info_path, {'data': self.data,
'time_index': self.time_index,
'start_time': self.start_time,
'end_time': self.end_time})
def load(self):
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
info = load_info(info_path)
self.data, self.time_index, self._start_time, self._end_time = \
info['data'], info['time_index'], info['start_time'], info['end_time']
@property
def start_time(self):
r""" Start time of events in the temporal graph
def __getitem__(self, idx): Returns
if idx >= len(self) or idx < 0: -------
int
"""
return self._start_time
@property
def end_time(self):
r""" End time of events in the temporal graph
Returns
-------
int
"""
return self._end_time
def __getitem__(self, t):
r""" Get graph by with events before time `t + self.start_time`
Parameters
----------
t : int
Time, its value must be in range [0, `self.end_time` - `self.start_time`]
Returns
-------
dgl.DGLGraph
graph structure and edge feature
- edata['rel_type']: edge type
"""
if t >= len(self) or t < 0:
raise IndexError("Index out of range") raise IndexError("Index out of range")
i = idx + self.start_time i = t + self.start_time
row_mask = self.time_index <= i row_mask = self.time_index <= i
edges = self.data[row_mask][:, [0, 2]] edges = self.data[row_mask][:, [0, 2]]
rate = self.data[row_mask][:, 1] rate = self.data[row_mask][:, 1]
g = convert.graph((edges[:, 0], edges[:, 1])) g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = rate.reshape(-1, 1) g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
return g return g
def __len__(self): def __len__(self):
return self.end_time - self.start_time + 1 r"""Number of graphs in the dataset"""
return self._end_time - self._start_time + 1
@property
def num_nodes(self):
return 23033
@property @property
def is_temporal(self): def is_temporal(self):
r""" Does the dataset contain temporal graphs
Returns
-------
bool
"""
return True return True
GDELT = GDELTDataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment