import os

import torch as th
import torch.nn as nn
import tqdm


class PBar(object):
    """Progress-bar context manager usable as a ``urlretrieve`` report hook.

    The instance is called with ``(blockno, readsize, totalsize)``; the tqdm
    bar is created lazily on the first call because the total size is only
    known at that point.
    """

    def __enter__(self):
        self.t = None
        return self

    def __call__(self, blockno, readsize, totalsize):
        """Advance the bar to ``blockno * readsize`` bytes.

        Args:
            blockno: Number of blocks transferred so far (0 on the first call).
            readsize: Size of one block in bytes.
            totalsize: Total size in bytes; a non-positive value means the
                size is unknown and the bar is shown without a total.
        """
        if self.t is None:
            total = totalsize if totalsize is not None and totalsize > 0 else None
            self.t = tqdm.tqdm(total=total, unit="B", unit_scale=True)
        downloaded = blockno * readsize
        if self.t.total is not None:
            # The last block is usually shorter than ``readsize``; clamp so the
            # bar never runs past the announced total.
            downloaded = min(downloaded, self.t.total)
        # Update by the delta so the initial ``blockno == 0`` call adds nothing.
        self.t.update(downloaded - self.t.n)

    def __exit__(self, exc_type, exc_value, traceback):
        # The bar may never have been created (e.g. the transfer failed before
        # the first hook call); closing ``None`` would mask the real error.
        if self.t is not None:
            self.t.close()


class AminerDataset(object):
    """Locate the Aminer dataset under ``path``, downloading it if missing.

    If ``<path>/aminer.txt`` does not exist, ``aminer.zip`` is fetched from
    ``self.url`` and extracted into ``path``.

    Attributes set on the instance:
        url: Location of the zipped dataset.
        fn: Path to ``aminer.txt`` inside ``path``.
    """

    def __init__(self, path):
        self.url = "https://data.dgl.ai/dataset/aminer.zip"
        self.fn = os.path.join(path, "aminer.txt")

        if not os.path.exists(self.fn):
            print("File not found. Downloading from", self.url)
            self._download_and_extract(path, "aminer.zip")

    def _download_and_extract(self, path, filename):
        """Download ``self.url`` to ``path/filename`` and unzip it into ``path``.

        Args:
            path: Destination directory; created if it does not exist.
            filename: Name under which the downloaded archive is stored.
        """
        import urllib.request
        import zipfile

        # An empty ``path`` means the current directory, which already exists.
        if path:
            os.makedirs(path, exist_ok=True)

        fn = os.path.join(path, filename)

        with PBar() as pb:
            urllib.request.urlretrieve(self.url, fn, pb)
        print("Download finished. Unzipping the file...")

        with zipfile.ZipFile(fn) as zf:
            zf.extractall(path)
        print("Unzip finished.")


class CustomDataset(object):
    """Custom dataset generated by sampler.py (e.g. NetDBIS).

    Attributes set on the instance:
        fn: The file path passed to the constructor, stored unchanged.
    """

    def __init__(self, path):
        self.fn = path