"""Prepare Visual Genome datasets"""
import argparse
import json
import os
import pickle
import random
import shutil
import zipfile

import tqdm
from gluoncv.utils import download, makedirs

_TARGET_DIR = os.path.expanduser("~/.mxnet/datasets/visualgenome")


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Initialize Visual Genome dataset.",
        epilog="Example: python prepare_visualgenome.py --download-dir ~/visualgenome",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--download-dir",
        type=str,
        default="~/visualgenome/",
        help="dataset directory on disk",
    )
    parser.add_argument(
        "--no-download",
        action="store_true",
        help="disable automatic download if set",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="overwrite downloaded files if set, in case they are corrupted",
    )
    args = parser.parse_args()
    return args
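# Example invocations (illustrative):
#   python prepare_visualgenome.py --download-dir ~/visualgenome
#   python prepare_visualgenome.py --download-dir /data/visualgenome --overwrite
#   python prepare_visualgenome.py --download-dir /data/visualgenome --no-download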


def download_vg(path, overwrite=False):
    """Download and extract the Visual Genome image archives into `path`."""
    _DOWNLOAD_URLS = [
        (
            "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "a055367f675dd5476220e9b93e4ca9957b024b94",
        ),
        (
            "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "2add3aab77623549e92b7f15cda0308f50b64ecf",
        ),
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(
            url, path=path, overwrite=overwrite, sha1_hash=checksum
        )
        # extract
        if filename.endswith("zip"):
            with zipfile.ZipFile(filename) as zf:
                zf.extractall(path=path)
    # move all images into folder `VG_100K`
    vg_100k_path = os.path.join(path, "VG_100K")
    vg_100k_2_path = os.path.join(path, "VG_100K_2")
    files_2 = os.listdir(vg_100k_2_path)
    for fl in tqdm.tqdm(files_2, desc="Moving images into VG_100K"):
        shutil.move(
            os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)
        )
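# Illustrative sketch (not part of the original script): gluoncv.utils.download
# already verifies each archive against the SHA-1 checksums above when
# `sha1_hash` is given. The helper below shows roughly how such a check works,
# using hashlib, in case you want to re-verify a previously downloaded file.
def _sha1_matches(filename, expected_sha1, chunk_size=1024 * 1024):
    import hashlib

    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha1.update(chunk)
    return sha1.hexdigest() == expected_sha1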


def download_json(path, overwrite=False):
    """Download and extract the Visual Genome metadata JSON archive into `path`."""
    url = "https://data.dgl.ai/dataset/vg.zip"
    filename = download(url, path=path, overwrite=overwrite)
    with zipfile.ZipFile(filename) as zf:
        zf.extractall(path=path)
    json_path = os.path.join(path, "vg")
    json_files = os.listdir(json_path)
    for fl in json_files:
        shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))
    os.rmdir(json_path)
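# Illustrative sketch (not part of the original script): after download_json()
# runs, the metadata JSON files sit directly under `path`. This helper simply
# parses each one to confirm it is neither truncated nor corrupt; no specific
# file names are assumed.
def _check_json_files(path):
    for name in os.listdir(path):
        if name.endswith(".json"):
            with open(os.path.join(path, name)) as f:
                json.load(f)  # raises json.JSONDecodeError if the file is broken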


if __name__ == "__main__":
    args = parse_args()
    path = os.path.expanduser(args.download_dir)
    if not os.path.isdir(path):
        if args.no_download:
            raise ValueError(
                "{} is not a valid directory; make sure it is present,"
                " or do not pass --no-download so the dataset can be"
                " downloaded automatically.".format(path)
            )
        else:
            download_vg(path, overwrite=args.overwrite)
            download_json(path, overwrite=args.overwrite)

    # make symlink
    makedirs(os.path.expanduser("~/.mxnet/datasets"))
    # replace any existing symlink/empty directory with a fresh link
    if os.path.islink(_TARGET_DIR):
        os.unlink(_TARGET_DIR)
    elif os.path.isdir(_TARGET_DIR):
        os.rmdir(_TARGET_DIR)
    os.symlink(path, _TARGET_DIR)
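    # Illustrative check (not part of the original script): make sure the
    # symlink resolves to the prepared dataset directory, so loaders that look
    # under ~/.mxnet/datasets/visualgenome find VG_100K and the JSON files.
    assert os.path.realpath(_TARGET_DIR) == os.path.realpath(path), (
        "symlink does not point at the prepared Visual Genome directory"
    )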