import os.path as osp

import mmcv
import numpy as np
from torch.utils.data import Dataset

from .pipelines import Compose
from .registry import DATASETS


11
@DATASETS.register_module
class CustomDataset(Dataset):
    """Custom dataset for detection.

    Annotation format:
    [
        {
            'filename': 'a.jpg',
            'width': 1280,
            'height': 720,
            'ann': {
                'bboxes': <np.ndarray> (n, 4),
                'labels': <np.ndarray> (n, ),
                'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                'labels_ignore': <np.ndarray> (k, ) (optional field)
            }
        },
        ...
    ]

    The `ann` field is optional for testing.

    Args:
        ann_file: Annotation file path, passed to `load_annotations`.
        pipeline: Pipeline configuration handed to `Compose`.
        data_root: Optional root directory. When given, every relative path
            among `ann_file`, `img_prefix`, `seg_prefix` and `proposal_file`
            is joined onto it; absolute paths and `None` prefixes are left
            untouched.
        img_prefix: Image prefix, forwarded to the pipeline via the
            `results` dict.
        seg_prefix: Segmentation prefix, forwarded the same way.
        proposal_file: Optional proposal file, loaded with
            `load_proposals`.
        test_mode: If True, no image filtering is done, no group flag is
            set, and `__getitem__` returns test samples (without
            annotations).
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)

        # load annotations (and proposals)
        self.img_infos = self.load_annotations(self.ann_file)
        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None

        if not self.test_mode:
            # filter out too-small images during training, keeping the
            # proposals aligned with the surviving image infos
            valid_inds = self._filter_imgs()
            self.img_infos = [self.img_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler; must run after filtering so
            # the flag array matches the final img_infos
            self._set_group_flag()

        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_infos)

    def load_annotations(self, ann_file):
        """Load the image infos from `ann_file` with `mmcv.load`."""
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        """Load the proposals from `proposal_file` with `mmcv.load`."""
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        """Return the `ann` entry of the image info at index `idx`."""
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Add dataset-level fields to `results` in place.

        Sets the image/segmentation prefixes and the proposal file, and
        initialises `bbox_fields` and `mask_fields` to empty lists.
        """
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Return the indices of images whose shorter side is >= min_size."""
        return [
            i for i, img_info in enumerate(self.img_infos)
            if min(img_info['width'], img_info['height']) >= min_size
        ]

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i, img_info in enumerate(self.img_infos):
            # `width > height` is the same test as `width / height > 1` for
            # positive sizes, but cannot raise ZeroDivisionError on a
            # degenerate zero height.
            if img_info['width'] > img_info['height']:
                self.flag[i] = 1

    def _rand_another(self, idx):
        """Return a random index whose group flag equals that of `idx`."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Return the pipeline output for the sample at `idx`.

        In test mode the test sample is returned directly. Otherwise, when
        `prepare_train_img` yields None, another index is drawn at random
        from the same group and the preparation is retried until it
        succeeds.
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        """Build the training `results` dict for `idx` and run the pipeline.

        The dict carries the image info, its annotation and, when proposals
        were loaded, the proposals of that image.
        """
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Build the test `results` dict for `idx` and run the pipeline.

        Same as `prepare_train_img` but without annotation info.
        """
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)