Unverified Commit 2b9d06b6 authored by Quan (Andy) Gan, committed by GitHub

[Data] Retry downloading once if dataset initialization fails (#1520)



* dataloader guard

* more fixes

* sets overwrite to true by default

* Update utils.py

* fix
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 90d2118d
@@ -4,6 +4,7 @@ import os
 import datetime
 
 from .utils import get_download_dir, download, extract_archive
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
@@ -28,14 +29,16 @@ class BitcoinOTC(object):
         self.dir = get_download_dir()
         self.zip_path = os.path.join(
             self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv.gz")
-        download(self._url, path=self.zip_path)
-        extract_archive(self.zip_path, os.path.join(
-            self.dir, 'bitcoin'))
         self.path = os.path.join(
             self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv")
         self.graphs = []
         self._load(self.path)
 
+    def _download_and_extract(self):
+        download(self._url, path=self.zip_path)
+        extract_archive(self.zip_path, os.path.join(self.dir, 'bitcoin'))
+
+    @retry_method_with_fix(_download_and_extract)
     def _load(self, filename):
         data = np.loadtxt(filename, delimiter=',').astype(np.int64)
         data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()
...
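The user-visible effect, sketched on BitcoinOTC above (a hedged example, assuming BitcoinOTC is exported from dgl.data and that an empty csv makes the numpy-based loading raise): a corrupt cached file no longer requires deleting it by hand.

    import os
    from dgl.data import BitcoinOTC
    from dgl.data.utils import get_download_dir

    # Simulate a corrupt cached file left over from an interrupted run.
    csv_path = os.path.join(get_download_dir(), 'bitcoin',
                            'soc-sign-bitcoinotc.csv')
    if os.path.exists(csv_path):
        open(csv_path, 'w').close()   # truncate to zero bytes

    # _load() fails once on the empty file, _download_and_extract()
    # re-fetches the archive, and the single retry then succeeds.
    data = BitcoinOTC()
    print(len(data.graphs))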
@@ -12,6 +12,7 @@ from collections import defaultdict
 from ..utils import mol_to_complete_graph, atom_type_one_hot, \
     atom_hybridization_one_hot, atom_is_aromatic
 from ...utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
+from ....utils import retry_method_with_fix
 from .... import backend as F
 from ....contrib.deprecation import deprecated
@@ -199,18 +200,25 @@ class TencentAlchemyDataset(object):
         else:
             file_name = "%s_single_sdf" % (mode)
 
+        self._file_dir = file_dir
         self.file_dir = pathlib.Path(file_dir, file_name)
         self._url = 'dataset/alchemy/'
         self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
-        download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
+        self._file_name = file_name
+
+        self._load(mol_to_graph, node_featurizer, edge_featurizer)
+
+    def _download_and_extract(self):
+        download(
+            _get_dgl_url(self._url + self._file_name + '.zip'),
+            path=str(self.zip_file_path))
         if not os.path.exists(str(self.file_dir)):
             archive = zipfile.ZipFile(self.zip_file_path)
-            archive.extractall(file_dir)
+            archive.extractall(self._file_dir)
             archive.close()
 
-        self._load(mol_to_graph, node_featurizer, edge_featurizer)
+    @retry_method_with_fix(_download_and_extract)
     def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
         if self.load:
             self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
...
@@ -5,6 +5,7 @@ import pandas as pd
 from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
 from ...utils import get_download_dir, download, _get_dgl_url, extract_archive
+from ....utils import retry_method_with_fix
 from .... import backend as F
 from ....contrib.deprecation import deprecated
@@ -80,8 +81,6 @@ class PDBBind(object):
         root_dir_path = get_download_dir()
         data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
         extracted_data_path = root_dir_path + '/pdbbind_v2015'
-        download(_get_dgl_url(self._url), path=data_path)
-        extract_archive(data_path, extracted_data_path)
 
         if subset == 'core':
             index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
@@ -92,6 +91,9 @@ class PDBBind(object):
                 'Expect the subset_choice to be either '
                 'core or refined, got {}'.format(subset))
 
+        self._data_path = data_path
+        self._extracted_data_path = extracted_data_path
+
         self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
                          add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
                          construct_graph_and_featurize, zero_padding, num_processes)
@@ -135,6 +137,11 @@ class PDBBind(object):
             self.protein_mols.append(protein_mol)
             self.protein_coordinates.append(protein_coordinates)
 
+    def _download_and_extract(self):
+        download(_get_dgl_url(self._url), path=self._data_path)
+        extract_archive(self._data_path, self._extracted_data_path)
+
+    @retry_method_with_fix(_download_and_extract)
     def _preprocess(self, root_path, index_label_file, load_binding_pocket,
                     add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
                     construct_graph_and_featurize, zero_padding, num_processes):
...
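Every dataset class in this commit follows the shape that PDBBind shows most explicitly: because the fix method is invoked as fix_method(self) with no other arguments, any path or URL it needs must be stashed on the instance before the decorated method runs. A minimal sketch of the pattern, using a hypothetical class and placeholder URL (only download, extract_archive, get_download_dir, and retry_method_with_fix are taken from the diff):

    import os
    from dgl.data.utils import download, extract_archive, get_download_dir
    from dgl.utils import retry_method_with_fix

    class _ExampleDataset(object):
        _url = 'https://example.com/mydata.tar.gz'   # placeholder URL

        def __init__(self):
            # Stash everything the parameterless fix method will need.
            self._data_path = os.path.join(get_download_dir(), 'mydata.tar.gz')
            self._extracted_path = os.path.join(get_download_dir(), 'mydata')
            self._load()

        def _download_and_extract(self):
            download(self._url, path=self._data_path)
            extract_archive(self._data_path, self._extracted_path)

        @retry_method_with_fix(_download_and_extract)
        def _load(self):
            # Raises (e.g. FileNotFoundError) when the files are absent or
            # corrupt, which triggers the one-shot download-and-retry.
            with open(os.path.join(self._extracted_path, 'data.txt')) as f:
                self.lines = f.readlines()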
@@ -4,6 +4,7 @@ import sys
 from .csv_dataset import MoleculeCSVDataset
 from ..utils import smiles_to_bigraph
 from ...utils import get_download_dir, download, _get_dgl_url
+from ....utils import retry_method_with_fix
 from ....base import dgl_warning
 from ....contrib.deprecation import deprecated
@@ -42,9 +43,15 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
         self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
         data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
-        download(_get_dgl_url(self._url), path=data_path)
-        df = pd.read_csv(data_path)
+        self._data_path = data_path
+        self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
+
+    def _download(self):
+        download(_get_dgl_url(self._url), path=self._data_path)
+
+    @retry_method_with_fix(_download)
+    def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
+        df = pd.read_csv(data_path)
         super(PubChemBioAssayAromaticity, self).__init__(
             df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
             "pubchem_aromaticity_dglgraph.bin", load=load)
@@ -4,6 +4,7 @@ from .csv_dataset import MoleculeCSVDataset
 from ..utils import smiles_to_bigraph
 from ...utils import get_download_dir, download, _get_dgl_url
 from .... import backend as F
+from ....utils import retry_method_with_fix
 from ....base import dgl_warning
 from ....contrib.deprecation import deprecated
@@ -55,7 +56,15 @@ class Tox21(MoleculeCSVDataset):
         self._url = 'dataset/tox21.csv.gz'
         data_path = get_download_dir() + '/tox21.csv.gz'
-        download(_get_dgl_url(self._url), path=data_path)
+        self._data_path = data_path
+        self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
+
+    def _download(self):
+        download(_get_dgl_url(self._url), path=self._data_path)
+
+    @retry_method_with_fix(_download)
+    def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
         df = pd.read_csv(data_path)
         self.id = df['mol_id']
...
@@ -12,6 +12,7 @@ import scipy.sparse as sp
 import os, sys
 
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 from ..graph import batch as graph_batch
@@ -48,10 +49,13 @@ class CitationGraphDataset(object):
         self.name = name
         self.dir = get_download_dir()
         self.zip_file_path='{}/{}.zip'.format(self.dir, name)
-        download(_get_dgl_url(_urls[name]), path=self.zip_file_path)
-        extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
         self._load()
 
+    def _download_and_extract(self):
+        download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
+        extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
+
+    @retry_method_with_fix(_download_and_extract)
     def _load(self):
         """Loads input data from gcn/data directory
@@ -307,10 +311,13 @@ class CoraBinary(object):
         self.dir = get_download_dir()
         self.name = 'cora_binary'
         self.zip_file_path='{}/{}.zip'.format(self.dir, self.name)
+        self._load()
+
+    def _download_and_extract(self):
         download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
         extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
-        self._load()
 
+    @retry_method_with_fix(_download_and_extract)
     def _load(self):
         root = '{}/{}'.format(self.dir, self.name)
         # load graphs
...
@@ -4,6 +4,7 @@ import os
 import datetime
 
 from .utils import get_download_dir, download, extract_archive, loadtxt
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
@@ -39,10 +40,6 @@ class GDELT(object):
         self.dir = get_download_dir()
         self.mode = mode
         # self.graphs = []
-        for dname in self._url:
-            dpath = os.path.join(
-                self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
-            download(self._url[dname.lower()], path=dpath)
         train_data = loadtxt(os.path.join(
             self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64)
         if self.mode == 'train':
@@ -62,6 +59,13 @@ class GDELT(object):
             self._load(np.concatenate(
                 [train_data, val_data, test_data], axis=0))
 
+    def _download(self):
+        for dname in self._url:
+            dpath = os.path.join(
+                self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
+            download(self._url[dname.lower()], path=dpath)
+
+    @retry_method_with_fix(_download)
     def _load(self, data):
         # The source code is not released, but the paper indicates there're
         # totally 137 samples. The cutoff below has exactly 137 samples.
...
@@ -13,6 +13,7 @@ import numpy as np
 from .. import backend as F
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 
 _url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
@@ -48,7 +49,7 @@ class GINDataset(object):
         self.name = name  # MUTAG
         self.ds_name = 'nig'
-        self.extract_dir = self._download()
+        self.extract_dir = self._get_extract_dir()
         self.file = self._file_path()
 
         self.self_loop = self_loop
@@ -101,20 +102,22 @@ class GINDataset(object):
         """
         return self.graphs[idx], self.labels[idx]
 
+    def _get_extract_dir(self):
+        return os.path.join(get_download_dir(), "{}".format(self.ds_name))
+
     def _download(self):
         download_dir = get_download_dir()
         zip_file_path = os.path.join(
             download_dir, "{}.zip".format(self.ds_name))
         # TODO move to dgl host _get_dgl_url
         download(_url, path=zip_file_path)
-        extract_dir = os.path.join(
-            download_dir, "{}".format(self.ds_name))
+        extract_dir = self._get_extract_dir()
         extract_archive(zip_file_path, extract_dir)
-        return extract_dir
 
     def _file_path(self):
         return os.path.join(self.extract_dir, "dataset", self.name, "{}.txt".format(self.name))
 
+    @retry_method_with_fix(_download)
     def _load(self):
         """ Loads input dataset from dataset/NAME/NAME.txt file
...
@@ -2,6 +2,7 @@ import scipy.sparse as sp
 import numpy as np
 import os
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 
 __all__=["AmazonCoBuy", "Coauthor", 'CoraFull']
@@ -24,12 +25,15 @@ class GNNBenchmarkDataset(object):
         self.dir = get_download_dir()
         self.path = os.path.join(
             self.dir, 'gnn_benckmark', self._url[name.lower()].split('/')[-1])
-        download(self._url[name.lower()], path=self.path)
+        self._name = name
         g = self.load_npz(self.path)
         self.data = [g]
 
-    @staticmethod
-    def load_npz(file_name):
+    def _download(self):
+        download(self._url[self._name.lower()], path=self.path)
+
+    @retry_method_with_fix(_download)
+    def load_npz(self, file_name):
         with np.load(file_name, allow_pickle=True) as loader:
             loader = dict(loader)
             num_nodes = loader['adj_shape'][0]
...
@@ -5,6 +5,7 @@ import datetime
 import warnings
 
 from .utils import get_download_dir, download, extract_archive, loadtxt
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
@@ -40,10 +41,6 @@ class ICEWS18(object):
         self.dir = get_download_dir()
         self.mode = mode
         self.graphs = []
-        for dname in self._url:
-            dpath = os.path.join(
-                self.dir, 'ICEWS18', self._url[dname.lower()].split('/')[-1])
-            download(self._url[dname.lower()], path=dpath)
         train_data = loadtxt(os.path.join(
             self.dir, 'ICEWS18', 'train.txt'), delimiter='\t').astype(np.int64)
         if self.mode == 'train':
@@ -63,6 +60,13 @@ class ICEWS18(object):
             self._load(np.concatenate(
                 [train_data, val_data, test_data], axis=0))
 
+    def _download(self):
+        for dname in self._url:
+            dpath = os.path.join(
+                self.dir, 'ICEWS18', self._url[dname.lower()].split('/')[-1])
+            download(self._url[dname.lower()], path=dpath)
+
+    @retry_method_with_fix(_download)
     def _load(self, data):
         num_nodes = 23033
         # The source code is not released, but the paper indicates there're
...
@@ -7,6 +7,7 @@ import networkx as nx
 from networkx.readwrite import json_graph
 
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 
 _url = 'dataset/ppi.zip'
@@ -32,9 +33,18 @@ class PPIDataset(object):
         """
         assert mode in ['train', 'valid', 'test']
         self.mode = mode
+        self._name = 'ppi'
+        self._dir = get_download_dir()
+        self._zip_file_path = '{}/{}.zip'.format(self._dir, self._name)
         self._load()
         self._preprocess()
 
+    def _download(self):
+        download(_get_dgl_url(_url), path=self._zip_file_path)
+        extract_archive(self._zip_file_path,
+                        '{}/{}'.format(self._dir, self._name))
+
+    @retry_method_with_fix(_download)
     def _load(self):
         """Loads input data.
@@ -52,34 +62,28 @@ class PPIDataset(object):
         object and the length of it is equal the number of nodes,
         it's like [1, 1, 2, 1...20].
         """
-        name = 'ppi'
-        dir = get_download_dir()
-        zip_file_path = '{}/{}.zip'.format(dir, name)
-        download(_get_dgl_url(_url), path=zip_file_path)
-        extract_archive(zip_file_path,
-                        '{}/{}'.format(dir, name))
         print('Loading G...')
         if self.mode == 'train':
-            with open('{}/ppi/train_graph.json'.format(dir)) as jsonfile:
+            with open('{}/ppi/train_graph.json'.format(self._dir)) as jsonfile:
                 g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/train_labels.npy'.format(dir))
-            self.features = np.load('{}/ppi/train_feats.npy'.format(dir))
+            self.labels = np.load('{}/ppi/train_labels.npy'.format(self._dir))
+            self.features = np.load('{}/ppi/train_feats.npy'.format(self._dir))
             self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(dir))
+            self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(self._dir))
         if self.mode == 'valid':
-            with open('{}/ppi/valid_graph.json'.format(dir)) as jsonfile:
+            with open('{}/ppi/valid_graph.json'.format(self._dir)) as jsonfile:
                 g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/valid_labels.npy'.format(dir))
-            self.features = np.load('{}/ppi/valid_feats.npy'.format(dir))
+            self.labels = np.load('{}/ppi/valid_labels.npy'.format(self._dir))
+            self.features = np.load('{}/ppi/valid_feats.npy'.format(self._dir))
             self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(dir))
+            self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(self._dir))
         if self.mode == 'test':
-            with open('{}/ppi/test_graph.json'.format(dir)) as jsonfile:
+            with open('{}/ppi/test_graph.json'.format(self._dir)) as jsonfile:
                 g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/test_labels.npy'.format(dir))
-            self.features = np.load('{}/ppi/test_feats.npy'.format(dir))
+            self.labels = np.load('{}/ppi/test_labels.npy'.format(self._dir))
+            self.features = np.load('{}/ppi/test_feats.npy'.format(self._dir))
             self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(dir))
+            self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(self._dir))
 
     def _preprocess(self):
         if self.mode == 'train':
...
@@ -3,6 +3,7 @@ import numpy as np
 import os
 
 from .utils import get_download_dir, download
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 
 class QM7b(object):
@@ -21,10 +22,13 @@ class QM7b(object):
     def __init__(self):
         self.dir = get_download_dir()
         self.path = os.path.join(self.dir, 'qm7b', "qm7b.mat")
-        download(self._url, path=self.path)
         self.graphs = []
         self._load(self.path)
 
+    def _download(self):
+        download(self._url, path=self.path)
+
+    @retry_method_with_fix(_download)
     def _load(self, filename):
         data = io.loadmat(self.path)
         labels = data['T']
...
@@ -17,6 +17,7 @@ import numpy as np
 import dgl
 import dgl.backend as F
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 
 __all__ = ['AIFB', 'MUTAG', 'BGS', 'AM']
@@ -108,9 +109,17 @@ class RDFGraphDataset:
                  insert_reverse=True):
         download_dir = get_download_dir()
         zip_file_path = os.path.join(download_dir, '{}.zip'.format(name))
-        download(url, path=zip_file_path)
         self._dir = os.path.join(download_dir, name)
-        extract_archive(zip_file_path, self._dir)
+        self._url = url
+        self._zip_file_path = zip_file_path
+
+        self._load(print_every, insert_reverse, force_reload)
+
+    def _download(self):
+        download(self._url, path=self._zip_file_path)
+        extract_archive(self._zip_file_path, self._dir)
+
+    @retry_method_with_fix(_download)
+    def _load(self, print_every, insert_reverse, force_reload):
         self._print_every = print_every
         self._insert_reverse = insert_reverse
 
         if not force_reload and self.has_cache():
...
@@ -4,6 +4,7 @@ import scipy.sparse as sp
 import numpy as np
 import os, sys
 
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
@@ -14,14 +15,24 @@ class RedditDataset(object):
         if self_loop:
             self_loop_str = "_self_loop"
         zip_file_path = os.path.join(download_dir, "reddit{}.zip".format(self_loop_str))
-        download(_get_dgl_url("dataset/reddit{}.zip".format(self_loop_str)), path=zip_file_path)
         extract_dir = os.path.join(download_dir, "reddit{}".format(self_loop_str))
-        extract_archive(zip_file_path, extract_dir)
+        self._url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
+        self._zip_file_path = zip_file_path
+        self._extract_dir = extract_dir
+        self._self_loop_str = self_loop_str
+
+    def _download(self):
+        download(self._url, path=self._zip_file_path)
+        extract_archive(self._zip_file_path, self._extract_dir)
+
+    @retry_method_with_fix(_download)
+    def _load(self):
         # graph
-        coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str)))
+        coo_adj = sp.load_npz(os.path.join(
+            self._extract_dir, "reddit{}_graph.npz".format(self._self_loop_str)))
         self.graph = DGLGraph(coo_adj, readonly=True)
         # features and labels
-        reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz"))
+        reddit_data = np.load(os.path.join(self._extract_dir, "reddit_data.npz"))
         self.features = reddit_data["feature"]
         self.labels = reddit_data["label"]
         self.num_labels = 41
...
@@ -14,6 +14,7 @@ import os
 from .. import backend as F
 from ..graph import DGLGraph
 from .utils import download, extract_archive, get_download_dir, _get_dgl_url
+from ..utils import retry_method_with_fix
 
 __all__ = ['SSTBatch', 'SST']
@@ -55,14 +56,17 @@ class SST(object):
         self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
         self.pretrained_emb = None
         self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
-        download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
-        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
         self.trees = []
         self.num_classes = 5
         print('Preprocessing...')
         self._load()
         print('Dataset creation finished. #Trees:', len(self.trees))
 
+    def _download(self):
+        download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
+        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
+
+    @retry_method_with_fix(_download)
     def _load(self):
         from nltk.corpus.reader import BracketParseCorpusReader
         # load vocab file
...
@@ -4,6 +4,7 @@ import os
 import random
 
 from .utils import download, extract_archive, get_download_dir, loadtxt
+from ..utils import retry_method_with_fix
 from ..graph import DGLGraph
 
 class LegacyTUDataset(object):
@@ -28,7 +29,24 @@ class LegacyTUDataset(object):
         self.name = name
         self.hidden_size = hidden_size
 
-        self.extract_dir = self._download()
+        self.extract_dir = self._get_extract_dir()
+        self._load()
+
+    def _get_extract_dir(self):
+        download_dir = get_download_dir()
+        return os.path.join(download_dir, "tu_{}".format(self.name))
+
+    def _download(self):
+        download_dir = get_download_dir()
+        zip_file_path = os.path.join(
+            download_dir, "tu_{}.zip".format(self.name))
+        download(self._url.format(self.name), path=zip_file_path)
+        extract_archive(zip_file_path, self.extract_dir)
+
+    @retry_method_with_fix(_download)
+    def _load(self):
         self.data_mode = None
         self.max_allow_node = max_allow_node
@@ -124,17 +142,6 @@ class LegacyTUDataset(object):
     def __len__(self):
         return len(self.graph_lists)
 
-    def _download(self):
-        download_dir = get_download_dir()
-        zip_file_path = os.path.join(
-            download_dir,
-            "tu_{}.zip".format(
-                self.name))
-        download(self._url.format(self.name), path=zip_file_path)
-        extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
-        extract_archive(zip_file_path, extract_dir)
-        return extract_dir
-
     def _file_path(self, category):
         return os.path.join(self.extract_dir, self.name,
                             "{}_{}.txt".format(self.name, category))
...
@@ -5,8 +5,6 @@ import os
 import sys
 import hashlib
 import warnings
-import zipfile
-import tarfile
 import numpy as np
 import warnings
 import requests
@@ -80,7 +78,7 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
     return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
 
-def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, log=True):
+def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_ssl=True, log=True):
     """Download a given URL.
 
     Codes borrowed from mxnet/gluon/utils.py
@@ -94,6 +92,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
         current directory with the same name as in url.
     overwrite : bool, optional
         Whether to overwrite the destination file if it already exists.
+        By default always overwrites the downloaded file.
     sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
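A short, self-contained simulation (not DGL code) of why flipping the overwrite default to True matters for the retry: if download() kept skipping existing files, the fix method would leave the truncated file from the first failed attempt in place and the single retry would fail identically.

    import os
    import tempfile

    def fake_download(path, overwrite):
        # Stand-in for download(): skip existing files unless overwrite is set.
        if os.path.exists(path) and not overwrite:
            return
        with open(path, 'w') as f:
            f.write('good data')

    path = os.path.join(tempfile.mkdtemp(), 'data.csv')
    with open(path, 'w') as f:
        f.write('garb')                      # truncated first attempt

    fake_download(path, overwrite=False)
    assert open(path).read() == 'garb'       # old default: corrupt file survives
    fake_download(path, overwrite=True)
    assert open(path).read() == 'good data'  # new default: retry gets fresh data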
@@ -190,7 +189,7 @@ def check_sha1(filename, sha1_hash):
     return sha1.hexdigest() == sha1_hash
 
-def extract_archive(file, target_dir):
+def extract_archive(file, target_dir, overwrite=True):
     """Extract archive file.
 
     Parameters
@@ -199,18 +198,29 @@ def extract_archive(file, target_dir):
         Absolute path of the archive file.
     target_dir : str
         Target directory of the archive to be uncompressed.
+    overwrite : bool, default True
+        Whether to overwrite the contents inside the directory.
+        By default always overwrites.
     """
-    if os.path.exists(target_dir):
+    if os.path.exists(target_dir) and not overwrite:
         return
-    if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
-        archive = tarfile.open(file, 'r')
+    print('Extracting file to {}'.format(target_dir))
+    if file.endswith('.tar.gz') or file.endswith('.tar') or file.endswith('.tgz'):
+        import tarfile
+        with tarfile.open(file, 'r') as archive:
+            archive.extractall(path=target_dir)
+    elif file.endswith('.gz'):
+        import gzip
+        import shutil
+        with gzip.open(file, 'rb') as f_in:
+            with open(file[:-3], 'wb') as f_out:
+                shutil.copyfileobj(f_in, f_out)
     elif file.endswith('.zip'):
-        archive = zipfile.ZipFile(file, 'r')
+        import zipfile
+        with zipfile.ZipFile(file, 'r') as archive:
+            archive.extractall(path=target_dir)
     else:
         raise Exception('Unrecognized file type: ' + file)
-    print('Extracting file to {}'.format(target_dir))
-    archive.extractall(path=target_dir)
-    archive.close()
 
 def get_download_dir():
...
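A hedged usage sketch of the reworked extract_archive (hypothetical paths, and it assumes the .gz archive is already present, e.g. from a prior download() call): it now dispatches on the archive suffix, and a plain .gz file, as opposed to a .tar.gz, is decompressed next to the archive, to the path with the .gz suffix stripped, rather than into target_dir. That matches how BitcoinOTC lays out soc-sign-bitcoinotc.csv earlier in this diff.

    import os
    from dgl.data.utils import extract_archive, get_download_dir

    base = get_download_dir()
    gz_path = os.path.join(base, 'bitcoin', 'soc-sign-bitcoinotc.csv.gz')

    # '.tar.gz' / '.tar' / '.tgz' -> tarfile, extracted into target_dir
    # '.gz'                       -> gzip, output is gz_path minus '.gz'
    # '.zip'                      -> zipfile, extracted into target_dir
    extract_archive(gz_path, os.path.join(base, 'bitcoin'))
    assert os.path.exists(
        os.path.join(base, 'bitcoin', 'soc-sign-bitcoinotc.csv'))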
@@ -568,3 +568,33 @@ def check_eq_shape(input_):
         raise DGLError("The feature shape of source nodes: {} \
             should be equal to the feature shape of destination \
             nodes: {}.".format(src_feat_shape, dst_feat_shape))
+
+
+def retry_method_with_fix(fix_method):
+    """Decorator that executes a fix method before retrying again when the decorated method
+    fails once with any exception.
+
+    If the decorated method fails again, the execution fails with that exception.
+
+    Notes
+    -----
+    This decorator only works on class methods, and the fix function must also be a class
+    method.  It would not work on plain functions.
+
+    Parameters
+    ----------
+    fix_method : callable
+        The fix method to execute.  It should not accept any arguments.  Its return values
+        are ignored.
+    """
+    def _creator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            # pylint: disable=W0703,bare-except
+            try:
+                return func(self, *args, **kwargs)
+            except:
+                fix_method(self)
+                return func(self, *args, **kwargs)
+        return wrapper
+    return _creator
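A minimal, runnable demonstration (not part of the commit) of the decorator's mechanics: it receives the plain function object at class-definition time and invokes it as fix_method(self), which is why both the decorated method and the fix must live on the same class.

    from dgl.utils import retry_method_with_fix

    class Flaky(object):
        def __init__(self):
            self.fixed = False

        def _fix(self):
            print('running fix...')
            self.fixed = True

        @retry_method_with_fix(_fix)   # _fix is a plain function here, not bound
        def load(self):
            if not self.fixed:
                raise IOError('corrupt download')
            return 'ok'

    print(Flaky().load())   # first call fails, _fix runs, the retry returns 'ok'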