import os.path as osp

import mmcv
import numpy as np
from torch.utils.data import Dataset

from .pipelines import Compose
from .registry import DATASETS


@DATASETS.register_module
class CustomDataset(Dataset):
    """Custom dataset for detection.

    Annotation format:
    [
        {
            'filename': 'a.jpg',
            'width': 1280,
            'height': 720,
            'ann': {
                'bboxes': <np.ndarray> (n, 4),
                'labels': <np.ndarray> (n, ),
                'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                'labels_ignore': <np.ndarray> (k, ) (optional field)
            }
        },
        ...
    ]

    The `ann` field is optional for testing.
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)
        # load annotations (and proposals)
        self.img_infos = self.load_annotations(self.ann_file)
        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None
        # filter out images that are too small
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.img_infos = [self.img_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
        # set a group flag for the sampler (flags group images by aspect
        # ratio; see _set_group_flag)
        if not self.test_mode:
            self._set_group_flag()
        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __len__(self):
        return len(self.img_infos)

    def load_annotations(self, ann_file):
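        # The default implementation simply deserializes the annotation file
        # (json/yaml/pickle, dispatched by file extension inside mmcv.load);
        # dataset-specific subclasses typically override this method.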
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
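        # Seed the results dict with the path prefixes and the empty field
        # lists that downstream pipeline transforms append loaded bboxes,
        # masks and segmentation maps to.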
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Filter images too small."""
        valid_inds = []
        for i, img_info in enumerate(self.img_infos):
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            img_info = self.img_infos[i]
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _rand_another(self, idx):
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_img(idx)
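        # During training the pipeline may return None (e.g. when all ground
        # truth boxes are filtered out by a transform); in that case resample
        # another image from the same aspect-ratio group and try again.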
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)
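

# A minimal usage sketch (illustrative only, not part of the module): it
# assumes an annotation file in the format documented in the class docstring
# and a plausible training pipeline built from the standard transform names;
# paths, classes and transform parameters below are placeholders, not
# prescribed values.
#
#     from mmdet.datasets import CustomDataset
#
#     train_pipeline = [
#         dict(type='LoadImageFromFile'),
#         dict(type='LoadAnnotations', with_bbox=True),
#         dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
#         dict(type='RandomFlip', flip_ratio=0.5),
#         dict(type='Normalize',
#              mean=[123.675, 116.28, 103.53],
#              std=[58.395, 57.12, 57.375],
#              to_rgb=True),
#         dict(type='Pad', size_divisor=32),
#         dict(type='DefaultFormatBundle'),
#         dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
#     ]
#     dataset = CustomDataset(
#         ann_file='data/my_dataset/anns.pkl',
#         pipeline=train_pipeline,
#         img_prefix='data/my_dataset/images/')
#     print(len(dataset))
#     sample = dataset[0]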