process_dataset.py 5.18 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tarfile
from pathlib import Path
from typing import Tuple, Dict, List

from PIL import Image
from tqdm import tqdm

# Directory holding the original ImageNet archives and receiving the processed
# output. Taken from the environment (None when unset); main() lets
# --dataset-dir override it.
DATASETS_DIR = os.environ.get("DATASETS_DIR")

# Archive file names expected directly inside the dataset directory.
IMAGE_ARCHIVE_FILENAME = "ILSVRC2012_img_val.tar"
DEVKIT_ARCHIVE_FILENAME = "ILSVRC2012_devkit_t12.tar.gz"

# Member paths read from inside the devkit archive.
LABELS_REL_PATH = "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt"
META_REL_PATH = "ILSVRC2012_devkit_t12/data/meta.mat"

# Subdirectory of the dataset directory where processed images are written.
IMAGENET_DIRNAME = "imagenet"

# Default crop size, as (width, height).
TARGET_SIZE = 224, 224
# Length of the shorter image side after the aspect-preserving resize.
_RESIZE_MIN = 256


def parse_meta_mat(metafile) -> Dict[int, str]:
    """Build a class-id -> WNID mapping from the devkit's ``meta.mat``.

    Only leaf synsets (records whose children count, column 4, is zero) are kept.

    Args:
        metafile: Path or file-like object accepted by ``scipy.io.loadmat``.

    Returns:
        Dict from column 0 (class index) to column 1 (WNID) of each leaf record.
    """
    import scipy.io

    synsets = scipy.io.loadmat(metafile, squeeze_me=True)["synsets"]
    # Transpose the records into columns so column 4 (children count) can be scanned.
    columns = list(zip(*synsets))
    leaf_records = [record for record, n_children in zip(synsets, columns[4]) if n_children == 0]
    leaf_ids, leaf_wnids = list(zip(*leaf_records))[:2]
    return dict(zip(leaf_ids, leaf_wnids))


def _resize_dims(original_size, target_size, resize_min):
    """Return the ``(width, height)`` an image is resized to before center-cropping.

    The shorter side becomes exactly ``max(resize_min, *target_size)``, so a
    ``target_size`` crop always fits inside the resized image. The longer side is
    scaled by the same ratio (rounded down) to preserve the aspect ratio.

    Args:
        original_size: ``(width, height)`` of the source image in pixels.
        target_size: ``(width, height)`` of the final center crop.
        resize_min: Desired length of the shorter side after resizing.

    Returns:
        ``(width, height)`` to pass to ``Image.resize``.
    """
    width, height = original_size
    short_side = max(resize_min, *target_size)
    # Integer arithmetic keeps the short side exactly ``short_side``; a float scale
    # factor followed by int() can truncate it by one pixel (e.g. 255 instead of 256).
    if width <= height:
        return short_side, height * short_side // width
    return width * short_side // height, short_side


def _process_image(image_file, target_size):
    """Resize an image so its shorter side is ``_RESIZE_MIN``, then center-crop it.

    Args:
        image_file: File object (or path) accepted by ``PIL.Image.open``.
        target_size: ``(width, height)`` of the returned crop. If either dimension
            exceeds ``_RESIZE_MIN``, the shorter side is resized to that dimension
            instead, so the crop never extends past the image borders.

    Returns:
        A new ``PIL.Image.Image`` of size ``target_size``.
    """
    with Image.open(image_file) as image:
        resize_to = _resize_dims(image.size, target_size, _RESIZE_MIN)
        # Pillow's default resampling filter is used, as before.
        resized_image = image.resize(resize_to)

    # Central crop of target_size.
    left = (resize_to[0] - target_size[0]) // 2
    upper = (resize_to[1] - target_size[1]) // 2
    return resized_image.crop((left, upper, left + target_size[0], upper + target_size[1]))


def _parse_target_size(value: str) -> Tuple[int, int]:
    """Parse a ``<width>,<height>`` string into a tuple of two positive ints.

    Used as the argparse ``type`` of ``--target-size``.

    Raises:
        argparse.ArgumentTypeError: if the value is not exactly two comma-separated
            integers, or either integer is not positive.
    """
    import argparse

    try:
        width, height = (int(part) for part in value.split(","))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid target size {value!r}: expected <width>,<height>, e.g. 224,224"
        ) from None
    if width <= 0 or height <= 0:
        raise argparse.ArgumentTypeError(f"invalid target size {value!r}: width and height must be positive")
    return width, height


def _load_validation_labels(devkit_archive_path: Path) -> List[int]:
    """Read validation ground truth from the devkit and remap it to network class indices.

    The ground-truth ids are mapped to WNIDs via ``meta.mat``. Each WNID is then
    replaced by its position in the sorted list of all WNIDs present, which is how
    the network orders its output classes.

    Returns:
        One class index per non-blank line of the ground-truth file, in file order.
    """
    with tarfile.open(devkit_archive_path, mode="r") as devkit_archive_file:
        labels_file = devkit_archive_file.extractfile(LABELS_REL_PATH)
        # Lines are bytes; int() accepts them directly.
        idx_labels = [int(line) for line in labels_file.readlines() if line.strip()]
        meta_file = devkit_archive_file.extractfile(META_REL_PATH)
        idx_to_wnid = parse_meta_mat(meta_file)

    labels_wnid = [idx_to_wnid[idx] for idx in idx_labels]
    wnid_to_newidx = {wnid: new_cls for new_cls, wnid in enumerate(sorted(set(labels_wnid)))}
    return [wnid_to_newidx[wnid] for wnid in labels_wnid]


def _write_processed_images(image_archive_path: Path, labels: List[int], output_dir: Path, target_size) -> None:
    """Resize/crop every image in the archive and save it under ``output_dir/<class>/``.

    Images are paired with ``labels`` by sorted member name.

    Raises:
        RuntimeError: if the number of images differs from the number of labels.
    """
    with tarfile.open(image_archive_path, mode="r") as image_archive_file:
        # Work with TarInfo objects directly instead of looking each member up by name.
        members = sorted(
            (member for member in image_archive_file.getmembers() if member.isfile()),
            key=lambda member: member.name,
        )
        if len(members) != len(labels):
            # zip() would otherwise silently drop the surplus and misalign nothing visibly.
            raise RuntimeError(
                f"Found {len(members)} images in {image_archive_path} but {len(labels)} "
                f"validation labels; the image and devkit archives do not match."
            )
        for cls, member in tqdm(zip(labels, members), total=len(members)):
            output_path = output_dir / str(cls) / member.name
            original_image_file = image_archive_file.extractfile(member)
            processed_image = _process_image(original_image_file, target_size)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            processed_image.save(output_path.as_posix())


def main():
    """CLI entry point: preprocess the ImageNet validation archive into per-class folders.

    Raises:
        ValueError: if no dataset directory is given via $DATASETS_DIR or --dataset-dir.
        RuntimeError: if a required archive is missing or the archives do not match.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Resize and center-crop the ImageNet (ILSVRC2012) validation images "
        "and store them in per-class directories."
    )
    parser.add_argument(
        "--dataset-dir",
        help="Path to dataset directory where imagenet archives are stored and processed files will be saved.",
        required=False,
        default=DATASETS_DIR,
    )
    parser.add_argument(
        "--target-size",
        help="Size of target image. Format it as <width>,<height>.",
        required=False,
        # argparse also applies ``type`` to string defaults.
        default=",".join(map(str, TARGET_SIZE)),
        type=_parse_target_size,
    )
    args = parser.parse_args()

    if args.dataset_dir is None:
        raise ValueError(
            "Please set $DATASETS_DIR env variable to point dataset dir with original dataset archives "
            "and where processed files should be stored. Alternatively provide --dataset-dir CLI argument"
        )

    datasets_dir = Path(args.dataset_dir)
    target_size = args.target_size

    image_archive_path = datasets_dir / IMAGE_ARCHIVE_FILENAME
    if not image_archive_path.exists():
        raise RuntimeError(
            f"There should be {IMAGE_ARCHIVE_FILENAME} file in {datasets_dir}. "
            "You need to download the dataset from http://www.image-net.org/download."
        )

    devkit_archive_path = datasets_dir / DEVKIT_ARCHIVE_FILENAME
    if not devkit_archive_path.exists():
        raise RuntimeError(
            f"There should be {DEVKIT_ARCHIVE_FILENAME} file in {datasets_dir}. "
            "You need to download the dataset from http://www.image-net.org/download."
        )

    labels = _load_validation_labels(devkit_archive_path)
    _write_processed_images(image_archive_path, labels, datasets_dir / IMAGENET_DIRNAME, target_size)


if __name__ == "__main__":
    main()