"""Prepare Visual Genome datasets""" import os import shutil import argparse import zipfile import random import json import tqdm import pickle from gluoncv.utils import download, makedirs _TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome') def parse_args(): parser = argparse.ArgumentParser( description='Initialize Visual Genome dataset.', epilog='Example: python visualgenome.py --download-dir ~/visualgenome', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--download-dir', type=str, default='~/visualgenome/', help='dataset directory on disk') parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted') args = parser.parse_args() return args def download_vg(path, overwrite=False): _DOWNLOAD_URLS = [ ('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip', 'a055367f675dd5476220e9b93e4ca9957b024b94'), ('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip', '2add3aab77623549e92b7f15cda0308f50b64ecf'), ] makedirs(path) for url, checksum in _DOWNLOAD_URLS: filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) # extract if filename.endswith('zip'): with zipfile.ZipFile(filename) as zf: zf.extractall(path=path) # move all images into folder `VG_100K` vg_100k_path = os.path.join(path, 'VG_100K') vg_100k_2_path = os.path.join(path, 'VG_100K_2') files_2 = os.listdir(vg_100k_2_path) for fl in files_2: shutil.move(os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)) def download_json(path, overwrite=False): url = 'https://data.dgl.ai/dataset/vg.zip' output = 'vg.zip' download(url, path=path) with zipfile.ZipFile(output) as zf: zf.extractall(path=path) json_path = os.path.join(path, 'vg') json_files = os.listdir(json_path) for fl in json_files: shutil.move(os.path.join(json_path, fl), os.path.join(path, fl)) os.rmdir(json_path) if __name__ == '__main__': args = parse_args() path = os.path.expanduser(args.download_dir) if not os.path.isdir(path): if args.no_download: raise ValueError(('{} is not a valid directory, make sure it is present.' ' Or you should not disable "--no-download" to grab it'.format(path))) else: download_vg(path, overwrite=args.overwrite) download_json(path, overwrite=args.overwrite) # make symlink makedirs(os.path.expanduser('~/.mxnet/datasets')) if os.path.isdir(_TARGET_DIR): os.rmdir(_TARGET_DIR) os.symlink(path, _TARGET_DIR)