Unverified Commit 2b9d06b6 authored by Quan (Andy) Gan, committed by GitHub

[Data] Retry downloading once if dataset initialization fails (#1520)



* dataloader guard

* more fixes

* sets overwrite to true by default

* Update utils.py

* fix

Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 90d2118d
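The commit applies one pattern across every dataset below: the download/extract calls move out of __init__ into a private _download (or _download_and_extract) method, and the loading routine is decorated with retry_method_with_fix so that a failed load triggers one re-download and a single retry. Flipping the overwrite default of download() to True is what makes the retry effective: a file truncated by an interrupted download gets replaced instead of being skipped because it already exists. Below is a minimal sketch of the pattern, not code from the commit; ToyDataset and its URL are hypothetical, and the import path for retry_method_with_fix is assumed from the relative imports in the hunks.

import os
from dgl.data.utils import download, get_download_dir
from dgl.utils import retry_method_with_fix  # assumed public path for the decorator added below

class ToyDataset(object):
    _url = 'https://example.com/toy.csv'  # placeholder URL

    def __init__(self):
        self.path = os.path.join(get_download_dir(), 'toy.csv')
        self._load(self.path)  # no eager download in __init__ any more

    def _download(self):
        # The "fix": (re-)fetch the file. With overwrite=True as the new
        # default, a corrupted earlier copy is replaced rather than kept.
        download(self._url, path=self.path)

    @retry_method_with_fix(_download)
    def _load(self, filename):
        # A missing or corrupted file raises here; the decorator then runs
        # _download once and calls _load again. A second failure propagates.
        with open(filename) as f:
            self.lines = f.readlines()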
@@ -4,6 +4,7 @@ import os
import datetime
from .utils import get_download_dir, download, extract_archive
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
@@ -28,14 +29,16 @@ class BitcoinOTC(object):
self.dir = get_download_dir()
self.zip_path = os.path.join(
self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv.gz")
download(self._url, path=self.zip_path)
extract_archive(self.zip_path, os.path.join(
self.dir, 'bitcoin'))
self.path = os.path.join(
self.dir, 'bitcoin', "soc-sign-bitcoinotc.csv")
self.graphs = []
self._load(self.path)
def _download_and_extract(self):
download(self._url, path=self.zip_path)
extract_archive(self.zip_path, os.path.join(self.dir, 'bitcoin'))
@retry_method_with_fix(_download_and_extract)
def _load(self, filename):
data = np.loadtxt(filename, delimiter=',').astype(np.int64)
data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()  # re-index node IDs to start from 0
......
@@ -12,6 +12,7 @@ from collections import defaultdict
from ..utils import mol_to_complete_graph, atom_type_one_hot, \
atom_hybridization_one_hot, atom_is_aromatic
from ...utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated
@@ -199,18 +200,25 @@ class TencentAlchemyDataset(object):
else:
file_name = "%s_single_sdf" % (mode)
self._file_dir = file_dir
self.file_dir = pathlib.Path(file_dir, file_name)
self._url = 'dataset/alchemy/'
self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
download(_get_dgl_url(self._url + file_name + '.zip'), path=str(self.zip_file_path))
self._file_name = file_name
self._load(mol_to_graph, node_featurizer, edge_featurizer)
def _download_and_extract(self):
download(
_get_dgl_url(self._url + self._file_name + '.zip'),
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(self.zip_file_path)
archive.extractall(file_dir)
archive.extractall(self._file_dir)
archive.close()
self._load(mol_to_graph, node_featurizer, edge_featurizer)
@retry_method_with_fix(_download_and_extract)
def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
if self.load:
self.graphs, label_dict = load_graphs(osp.join(self.file_dir, "%s_graphs.bin" % self.mode))
......
@@ -5,6 +5,7 @@ import pandas as pd
from ..utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
from ...utils import get_download_dir, download, _get_dgl_url, extract_archive
from ....utils import retry_method_with_fix
from .... import backend as F
from ....contrib.deprecation import deprecated
@@ -80,8 +81,6 @@ class PDBBind(object):
root_dir_path = get_download_dir()
data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
extracted_data_path = root_dir_path + '/pdbbind_v2015'
download(_get_dgl_url(self._url), path=data_path)
extract_archive(data_path, extracted_data_path)
if subset == 'core':
index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
@@ -92,6 +91,9 @@ class PDBBind(object):
'Expect the subset_choice to be either '
'core or refined, got {}'.format(subset))
self._data_path = data_path
self._extracted_data_path = extracted_data_path
self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes)
@@ -135,6 +137,11 @@ class PDBBind(object):
self.protein_mols.append(protein_mol)
self.protein_coordinates.append(protein_coordinates)
def _download_and_extract(self):
download(_get_dgl_url(self._url), path=self._data_path)
extract_archive(self._data_path, self._extracted_data_path)
@retry_method_with_fix(_download_and_extract)
def _preprocess(self, root_path, index_label_file, load_binding_pocket,
add_hydrogens, sanitize, calc_charges, remove_hs, use_conformation,
construct_graph_and_featurize, zero_padding, num_processes):
......
@@ -4,6 +4,7 @@ import sys
from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
@@ -42,9 +43,15 @@ class PubChemBioAssayAromaticity(MoleculeCSVDataset):
self._url = 'dataset/pubchem_bioassay_aromaticity.csv'
data_path = get_download_dir() + '/pubchem_bioassay_aromaticity.csv'
download(_get_dgl_url(self._url), path=data_path)
df = pd.read_csv(data_path)
self._data_path = data_path
self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
def _download(self):
download(_get_dgl_url(self._url), path=self._data_path)
@retry_method_with_fix(_download)
def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
df = pd.read_csv(data_path)
super(PubChemBioAssayAromaticity, self).__init__(
df, smiles_to_graph, node_featurizer, edge_featurizer, "cano_smiles",
"pubchem_aromaticity_dglgraph.bin", load=load)
@@ -4,6 +4,7 @@ from .csv_dataset import MoleculeCSVDataset
from ..utils import smiles_to_bigraph
from ...utils import get_download_dir, download, _get_dgl_url
from .... import backend as F
from ....utils import retry_method_with_fix
from ....base import dgl_warning
from ....contrib.deprecation import deprecated
@@ -55,7 +56,15 @@ class Tox21(MoleculeCSVDataset):
self._url = 'dataset/tox21.csv.gz'
data_path = get_download_dir() + '/tox21.csv.gz'
download(_get_dgl_url(self._url), path=data_path)
self._data_path = data_path
self._load(data_path, smiles_to_graph, node_featurizer, edge_featurizer, load)
def _download(self):
download(_get_dgl_url(self._url), path=self._data_path)
@retry_method_with_fix(_download)
def _load(self, data_path, smiles_to_graph, node_featurizer, edge_featurizer, load):
df = pd.read_csv(data_path)
self.id = df['mol_id']
......
@@ -12,6 +12,7 @@ import scipy.sparse as sp
import os, sys
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
from ..graph import batch as graph_batch
@@ -48,10 +49,13 @@ class CitationGraphDataset(object):
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_get_dgl_url(_urls[name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
self._load()
def _download_and_extract(self):
download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
@retry_method_with_fix(_download_and_extract)
def _load(self):
"""Loads input data from gcn/data directory
@@ -307,10 +311,13 @@ class CoraBinary(object):
self.dir = get_download_dir()
self.name = 'cora_binary'
self.zip_file_path='{}/{}.zip'.format(self.dir, self.name)
self._load()
def _download_and_extract(self):
download(_get_dgl_url(_urls[self.name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, self.name))
self._load()
@retry_method_with_fix(_download_and_extract)
def _load(self):
root = '{}/{}'.format(self.dir, self.name)
# load graphs
......
@@ -4,6 +4,7 @@ import os
import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
@@ -39,10 +40,6 @@ class GDELT(object):
self.dir = get_download_dir()
self.mode = mode
# self.graphs = []
for dname in self._url:
dpath = os.path.join(
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
train_data = loadtxt(os.path.join(
self.dir, 'GDELT', 'train.txt'), delimiter='\t').astype(np.int64)
if self.mode == 'train':
@@ -62,6 +59,13 @@ class GDELT(object):
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _download(self):
for dname in self._url:
dpath = os.path.join(
self.dir, 'GDELT', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
@retry_method_with_fix(_download)
def _load(self, data):
# The source code is not released, but the paper indicates there are
# 137 samples in total. The cutoff below yields exactly 137 samples.
......
@@ -13,6 +13,7 @@ import numpy as np
from .. import backend as F
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
@@ -48,7 +49,7 @@ class GINDataset(object):
self.name = name # MUTAG
self.ds_name = 'nig'
self.extract_dir = self._download()
self.extract_dir = self._get_extract_dir()
self.file = self._file_path()
self.self_loop = self_loop
@@ -101,20 +102,22 @@
"""
return self.graphs[idx], self.labels[idx]
def _get_extract_dir(self):
return os.path.join(get_download_dir(), "{}".format(self.ds_name))
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir, "{}.zip".format(self.ds_name))
# TODO move to dgl host _get_dgl_url
download(_url, path=zip_file_path)
extract_dir = os.path.join(
download_dir, "{}".format(self.ds_name))
extract_dir = self._get_extract_dir()
extract_archive(zip_file_path, extract_dir)
return extract_dir
def _file_path(self):
return os.path.join(self.extract_dir, "dataset", self.name, "{}.txt".format(self.name))
@retry_method_with_fix(_download)
def _load(self):
""" Loads input dataset from dataset/NAME/NAME.txt file
......
@@ -2,6 +2,7 @@ import scipy.sparse as sp
import numpy as np
import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
__all__=["AmazonCoBuy", "Coauthor", 'CoraFull']
@@ -24,12 +25,15 @@ class GNNBenchmarkDataset(object):
self.dir = get_download_dir()
self.path = os.path.join(
self.dir, 'gnn_benckmark', self._url[name.lower()].split('/')[-1])
download(self._url[name.lower()], path=self.path)
self._name = name
g = self.load_npz(self.path)
self.data = [g]
@staticmethod
def load_npz(file_name):
def _download(self):
download(self._url[self._name.lower()], path=self.path)
@retry_method_with_fix(_download)
def load_npz(self, file_name):
with np.load(file_name, allow_pickle=True) as loader:
loader = dict(loader)
num_nodes = loader['adj_shape'][0]
......
@@ -5,6 +5,7 @@ import datetime
import warnings
from .utils import get_download_dir, download, extract_archive, loadtxt
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
@@ -40,10 +41,6 @@ class ICEWS18(object):
self.dir = get_download_dir()
self.mode = mode
self.graphs = []
for dname in self._url:
dpath = os.path.join(
self.dir, 'ICEWS18', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
train_data = loadtxt(os.path.join(
self.dir, 'ICEWS18', 'train.txt'), delimiter='\t').astype(np.int64)
if self.mode == 'train':
@@ -63,6 +60,13 @@ class ICEWS18(object):
self._load(np.concatenate(
[train_data, val_data, test_data], axis=0))
def _download(self):
for dname in self._url:
dpath = os.path.join(
self.dir, 'ICEWS18', self._url[dname.lower()].split('/')[-1])
download(self._url[dname.lower()], path=dpath)
@retry_method_with_fix(_download)
def _load(self, data):
num_nodes = 23033
# The source code is not released, but the paper indicates there are
......
@@ -7,6 +7,7 @@ import networkx as nx
from networkx.readwrite import json_graph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
_url = 'dataset/ppi.zip'
@@ -32,9 +33,18 @@ class PPIDataset(object):
"""
assert mode in ['train', 'valid', 'test']
self.mode = mode
self._name = 'ppi'
self._dir = get_download_dir()
self._zip_file_path = '{}/{}.zip'.format(self._dir, self._name)
self._load()
self._preprocess()
def _download(self):
download(_get_dgl_url(_url), path=self._zip_file_path)
extract_archive(self._zip_file_path,
'{}/{}'.format(self._dir, self._name))
@retry_method_with_fix(_download)
def _load(self):
"""Loads input data.
@@ -52,34 +62,28 @@ class PPIDataset(object):
object whose length equals the number of nodes,
e.g. [1, 1, 2, 1...20].
"""
name = 'ppi'
dir = get_download_dir()
zip_file_path = '{}/{}.zip'.format(dir, name)
download(_get_dgl_url(_url), path=zip_file_path)
extract_archive(zip_file_path,
'{}/{}'.format(dir, name))
print('Loading G...')
if self.mode == 'train':
with open('{}/ppi/train_graph.json'.format(dir)) as jsonfile:
with open('{}/ppi/train_graph.json'.format(self._dir)) as jsonfile:
g_data = json.load(jsonfile)
self.labels = np.load('{}/ppi/train_labels.npy'.format(dir))
self.features = np.load('{}/ppi/train_feats.npy'.format(dir))
self.labels = np.load('{}/ppi/train_labels.npy'.format(self._dir))
self.features = np.load('{}/ppi/train_feats.npy'.format(self._dir))
self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(dir))
self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(self._dir))
if self.mode == 'valid':
with open('{}/ppi/valid_graph.json'.format(dir)) as jsonfile:
with open('{}/ppi/valid_graph.json'.format(self._dir)) as jsonfile:
g_data = json.load(jsonfile)
self.labels = np.load('{}/ppi/valid_labels.npy'.format(dir))
self.features = np.load('{}/ppi/valid_feats.npy'.format(dir))
self.labels = np.load('{}/ppi/valid_labels.npy'.format(self._dir))
self.features = np.load('{}/ppi/valid_feats.npy'.format(self._dir))
self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(dir))
self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(self._dir))
if self.mode == 'test':
with open('{}/ppi/test_graph.json'.format(dir)) as jsonfile:
with open('{}/ppi/test_graph.json'.format(self._dir)) as jsonfile:
g_data = json.load(jsonfile)
self.labels = np.load('{}/ppi/test_labels.npy'.format(dir))
self.features = np.load('{}/ppi/test_feats.npy'.format(dir))
self.labels = np.load('{}/ppi/test_labels.npy'.format(self._dir))
self.features = np.load('{}/ppi/test_feats.npy'.format(self._dir))
self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(dir))
self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(self._dir))
def _preprocess(self):
if self.mode == 'train':
@@ -169,4 +173,4 @@ class LegacyPPIDataset(PPIDataset):
if self.mode == 'valid':
return self.valid_graphs[item], self.features[self.valid_mask_list[item]], self.valid_labels[item]
if self.mode == 'test':
return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
\ No newline at end of file
return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
@@ -3,6 +3,7 @@ import numpy as np
import os
from .utils import get_download_dir, download
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
class QM7b(object):
@@ -21,10 +22,13 @@ class QM7b(object):
def __init__(self):
self.dir = get_download_dir()
self.path = os.path.join(self.dir, 'qm7b', "qm7b.mat")
download(self._url, path=self.path)
self.graphs = []
self._load(self.path)
def _download(self):
download(self._url, path=self.path)
@retry_method_with_fix(_download)
def _load(self, filename):
data = io.loadmat(self.path)
labels = data['T']
......
@@ -17,6 +17,7 @@ import numpy as np
import dgl
import dgl.backend as F
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
__all__ = ['AIFB', 'MUTAG', 'BGS', 'AM']
@@ -108,9 +109,17 @@ class RDFGraphDataset:
insert_reverse=True):
download_dir = get_download_dir()
zip_file_path = os.path.join(download_dir, '{}.zip'.format(name))
download(url, path=zip_file_path)
self._dir = os.path.join(download_dir, name)
extract_archive(zip_file_path, self._dir)
self._url = url
self._zip_file_path = zip_file_path
self._load(print_every, insert_reverse, force_reload)
def _download(self):
download(self._url, path=self._zip_file_path)
extract_archive(self._zip_file_path, self._dir)
@retry_method_with_fix(_download)
def _load(self, print_every, insert_reverse, force_reload):
self._print_every = print_every
self._insert_reverse = insert_reverse
if not force_reload and self.has_cache():
......
@@ -4,6 +4,7 @@ import scipy.sparse as sp
import numpy as np
import os, sys
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
@@ -14,14 +15,24 @@ class RedditDataset(object):
if self_loop:
self_loop_str = "_self_loop"
zip_file_path = os.path.join(download_dir, "reddit{}.zip".format(self_loop_str))
download(_get_dgl_url("dataset/reddit{}.zip".format(self_loop_str)), path=zip_file_path)
extract_dir = os.path.join(download_dir, "reddit{}".format(self_loop_str))
extract_archive(zip_file_path, extract_dir)
self._url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
self._zip_file_path = zip_file_path
self._extract_dir = extract_dir
self._self_loop_str = self_loop_str
def _download(self):
download(self._url, path=self._zip_file_path)
extract_archive(self._zip_file_path, self._extract_dir)
@retry_method_with_fix(_download)
def _load(self):
# graph
coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str)))
coo_adj = sp.load_npz(os.path.join(
self._extract_dir, "reddit{}_graph.npz".format(self._self_loop_str)))
self.graph = DGLGraph(coo_adj, readonly=True)
# features and labels
reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz"))
reddit_data = np.load(os.path.join(self._extract_dir, "reddit_data.npz"))
self.features = reddit_data["feature"]
self.labels = reddit_data["label"]
self.num_labels = 41
......
@@ -14,6 +14,7 @@ import os
from .. import backend as F
from ..graph import DGLGraph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
__all__ = ['SSTBatch', 'SST']
@@ -55,14 +56,17 @@ class SST(object):
self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
self.pretrained_emb = None
self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
self.trees = []
self.num_classes = 5
print('Preprocessing...')
self._load()
print('Dataset creation finished. #Trees:', len(self.trees))
def _download(self):
download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
@retry_method_with_fix(_download)
def _load(self):
from nltk.corpus.reader import BracketParseCorpusReader
# load vocab file
......
@@ -4,6 +4,7 @@ import os
import random
from .utils import download, extract_archive, get_download_dir, loadtxt
from ..utils import retry_method_with_fix
from ..graph import DGLGraph
class LegacyTUDataset(object):
@@ -28,7 +29,24 @@ class LegacyTUDataset(object):
self.name = name
self.hidden_size = hidden_size
self.extract_dir = self._download()
self.extract_dir = self._get_extract_dir()
self._load()
def _get_extract_dir(self):
    download_dir = get_download_dir()
    return os.path.join(download_dir, "tu_{}".format(self.name))

def _download(self):
    # Recompute the path here: zip_file_path was a local of the old
    # _download and is not in scope otherwise.
    zip_file_path = os.path.join(
        get_download_dir(), "tu_{}.zip".format(self.name))
    download(self._url.format(self.name), path=zip_file_path)
    extract_archive(zip_file_path, self._get_extract_dir())
@retry_method_with_fix(_download)
def _load(self):
self.data_mode = None
self.max_allow_node = max_allow_node
@@ -124,17 +142,6 @@ class LegacyTUDataset(object):
def __len__(self):
return len(self.graph_lists)
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir,
"tu_{}.zip".format(
self.name))
download(self._url.format(self.name), path=zip_file_path)
extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
extract_archive(zip_file_path, extract_dir)
return extract_dir
def _file_path(self, category):
return os.path.join(self.extract_dir, self.name,
"{}_{}.txt".format(self.name, category))
......
@@ -5,8 +5,6 @@ import os
import sys
import hashlib
import warnings
import zipfile
import tarfile
import numpy as np
import warnings
import requests
@@ -80,7 +78,7 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, log=True):
def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_ssl=True, log=True):
"""Download a given URL.
Codes borrowed from mxnet/gluon/utils.py
@@ -94,6 +92,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
current directory with the same name as in url.
overwrite : bool, optional
Whether to overwrite the destination file if it already exists.
By default, the destination file is always overwritten.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
@@ -190,7 +189,7 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash
def extract_archive(file, target_dir):
def extract_archive(file, target_dir, overwrite=False):
"""Extract archive file.
Parameters
@@ -199,18 +198,29 @@ def extract_archive(file, target_dir):
Absolute path of the archive file.
target_dir : str
Target directory of the archive to be uncompressed.
overwrite : bool, default False
    Whether to re-extract when the target directory already exists.
    By default, extraction is skipped if the target directory exists.
"""
if os.path.exists(target_dir):
if os.path.exists(target_dir) and not overwrite:
return
if file.endswith('.gz') or file.endswith('.tar') or file.endswith('.tgz'):
archive = tarfile.open(file, 'r')
print('Extracting file to {}'.format(target_dir))
if file.endswith('.tar.gz') or file.endswith('.tar') or file.endswith('.tgz'):
import tarfile
with tarfile.open(file, 'r') as archive:
archive.extractall(path=target_dir)
elif file.endswith('.gz'):
import gzip
import shutil
with gzip.open(file, 'rb') as f_in:
with open(file[:-3], 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
elif file.endswith('.zip'):
archive = zipfile.ZipFile(file, 'r')
import zipfile
with zipfile.ZipFile(file, 'r') as archive:
archive.extractall(path=target_dir)
else:
raise Exception('Unrecognized file type: ' + file)
print('Extracting file to {}'.format(target_dir))
archive.extractall(path=target_dir)
archive.close()
def get_download_dir():
......
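One note on the rewritten dispatch above: .tar.gz, .tar and .tgz archives go through tarfile, .zip goes through zipfile, but a plain .gz file (such as BitcoinOTC's soc-sign-bitcoinotc.csv.gz) is gunzipped next to the archive, to file[:-3], and target_dir is unused in that branch; this works for BitcoinOTC only because its archive already sits inside the target directory. Also, the print now runs before the suffix check, so an unrecognized file type still logs "Extracting file to ..." before raising. Illustrative calls, with hypothetical paths:

extract_archive('/tmp/reddit.zip', '/tmp/reddit')              # zipfile branch
extract_archive('/tmp/pdbbind_v2015.tar.gz', '/tmp/pdbbind')   # tarfile branch
extract_archive('/tmp/bitcoin/soc-sign-bitcoinotc.csv.gz', '/tmp/bitcoin')
# plain-.gz branch: writes /tmp/bitcoin/soc-sign-bitcoinotc.csv, ignores target_dir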
@@ -568,3 +568,33 @@ def check_eq_shape(input_):
raise DGLError("The feature shape of source nodes: {} \
should be equal to the feature shape of destination \
nodes: {}.".format(src_feat_shape, dst_feat_shape))
def retry_method_with_fix(fix_method):
"""Decorator that executes a fix method before retrying again when the decorated method
fails once with any exception.
If the decorated method fails again, the execution fails with that exception.
Notes
-----
This decorator only works on class methods, and the fix function must also be a class method.
It would not work on functions.
Parameters
----------
fix_func : callable
The fix method to execute. It should not accept any arguments. Its return values are
ignored.
"""
def _creator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
# pylint: disable=W0703,bare-except
try:
return func(self, *args, **kwargs)
except:
fix_method(self)
return func(self, *args, **kwargs)
return wrapper
return _creator
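For reference, a self-contained toy (not part of the commit) that exercises the retry semantics. Note that _fix is referenced as a plain function object inside the class body, which is why the wrapper calls fix_method(self) explicitly instead of relying on bound-method lookup:

from dgl.utils import retry_method_with_fix  # assumed import path for the decorator above

class Flaky(object):
    def __init__(self):
        self.fixed = False

    def _fix(self):
        # Stands in for a re-download: repair whatever the next attempt needs.
        self.fixed = True

    @retry_method_with_fix(_fix)  # plain function, not a bound method
    def load(self):
        if not self.fixed:
            raise IOError('simulated corrupted file')
        return 'ok'

print(Flaky().load())  # first attempt raises, _fix runs once, the retry returns 'ok'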