prepare_visualgenome.py 2.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Prepare Visual Genome datasets"""
import os
import shutil
import argparse
import zipfile
import random
import json
import tqdm
import pickle
from gluoncv.utils import download, makedirs

# Path under the user's home directory at which the script's entry point
# creates a symlink pointing to the prepared dataset directory.
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome')

def parse_args():
    """Parse the command-line options of the dataset preparation script.

    Returns
    -------
    argparse.Namespace
        Namespace with ``download_dir`` (str), ``no_download`` (bool) and
        ``overwrite`` (bool) attributes.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Initialize Visual Genome dataset.',
        epilog='Example: python visualgenome.py --download-dir ~/visualgenome')
    arg_parser.add_argument(
        '--download-dir', default='~/visualgenome/', type=str,
        help='dataset directory on disk')
    # The remaining options are plain boolean switches that default to False.
    switches = (
        ('--no-download', 'disable automatic download if set'),
        ('--overwrite', 'overwrite downloaded files if set, in case they are corrupted'),
    )
    for flag, description in switches:
        arg_parser.add_argument(flag, action='store_true', help=description)
    return arg_parser.parse_args()

def download_vg(path, overwrite=False):
    """Download the Visual Genome image archives and merge them into one folder.

    Each archive is fetched into ``path`` (with SHA-1 verification delegated
    to ``download``), unpacked there, and afterwards every entry found in
    ``<path>/VG_100K_2`` is moved into ``<path>/VG_100K``.

    Parameters
    ----------
    path : str
        Directory that receives the archives and the extracted images; it is
        created via ``makedirs``.
    overwrite : bool, default False
        Forwarded to ``download`` for each archive.
    """
    archives = (
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip',
         'a055367f675dd5476220e9b93e4ca9957b024b94'),
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip',
         '2add3aab77623549e92b7f15cda0308f50b64ecf'),
    )
    makedirs(path)
    for archive_url, sha1 in archives:
        archive_file = download(archive_url, path=path, overwrite=overwrite,
                                sha1_hash=sha1)
        # Only zip archives need to be unpacked.
        if not archive_file.endswith('zip'):
            continue
        with zipfile.ZipFile(archive_file) as archive:
            archive.extractall(path=path)

    # Consolidate all images in `VG_100K`.
    # NOTE(review): presumably the two archives unpack into `VG_100K` and
    # `VG_100K_2` respectively -- confirm against the archive contents.
    merged_dir = os.path.join(path, 'VG_100K')
    extra_dir = os.path.join(path, 'VG_100K_2')
    for image_name in os.listdir(extra_dir):
        source = os.path.join(extra_dir, image_name)
        destination = os.path.join(merged_dir, image_name)
        shutil.move(source, destination)

def download_json(path, overwrite=False):
    """Download the annotation archive and unpack its files directly into ``path``.

    The archive is fetched into ``path`` and extracted there; every entry of
    the resulting ``<path>/vg`` folder is then moved one level up into
    ``path`` and the emptied ``vg`` folder is removed.

    Parameters
    ----------
    path : str
        Directory that receives the archive and the extracted files.
    overwrite : bool, default False
        Forwarded to ``download``; re-fetches the archive when set.
    """
    url = 'https://data.dgl.ai/dataset/vg.zip'
    # Open the file that `download` reports rather than a bare 'vg.zip':
    # a relative name would be looked up in the current working directory,
    # which is generally not `path`.
    archive_file = download(url, path=path, overwrite=overwrite)
    with zipfile.ZipFile(archive_file) as zf:
        zf.extractall(path=path)

    # Flatten `<path>/vg/*` into `<path>/` and drop the now-empty folder.
    json_path = os.path.join(path, 'vg')
    for fl in os.listdir(json_path):
        shutil.move(os.path.join(json_path, fl),
                    os.path.join(path, fl))
    os.rmdir(json_path)

if __name__ == '__main__':
    # Script entry point: make sure the dataset directory exists (downloading
    # it unless that was disabled) and expose it at `_TARGET_DIR` via symlink.
    args = parse_args()
    path = os.path.expanduser(args.download_dir)
    if not os.path.isdir(path):
        if args.no_download:
            raise ValueError(('{} is not a valid directory, make sure it is present.'
                              ' Or you should not disable "--no-download" to grab it'.format(path)))
        download_vg(path, overwrite=args.overwrite)
        download_json(path, overwrite=args.overwrite)

    # make symlink
    makedirs(os.path.expanduser('~/.mxnet/datasets'))
    if os.path.islink(_TARGET_DIR):
        # A link left by a previous run (possibly dangling). `os.rmdir`
        # cannot remove symlinks, so unlink it instead.
        os.remove(_TARGET_DIR)
    elif os.path.isdir(_TARGET_DIR):
        # A real directory: only removable when empty, as before.
        os.rmdir(_TARGET_DIR)
    # Link to an absolute location; a relative target would be resolved
    # relative to the directory containing the link and end up dangling.
    os.symlink(os.path.abspath(path), _TARGET_DIR)