query_denoising.py 9.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule
from mmdet.core import bbox_xyxy_to_cxcywh
from mmdet.models.utils.transformer import inverse_sigmoid


class DnQueryGenerator(BaseModule):
    """Generates noised ground-truth queries for denoising training.

    Used by DN-DETR / DINO style detectors: during training, GT labels
    and boxes are duplicated into ``2 * num_groups`` copies (a positive
    and a negative half per group), perturbed with label and box noise,
    embedded, and padded into fixed-size query tensors that are prepended
    to the ordinary matching queries.  An attention mask keeps denoising
    groups isolated from each other and hides them from the matching
    queries.

    Args:
        num_queries (int): Number of ordinary matching queries of the
            detector head; used only to size the attention mask.
        hidden_dim (int): Channel dimension of the query embeddings.
        num_classes (int): Number of object classes; noisy labels are
            drawn uniformly from ``[0, num_classes)``.
        noise_scale (dict): ``'label'`` scales the label-flip probability
            and ``'box'`` scales the box-jitter magnitude.
        group_cfg (dict): If ``'dynamic'`` is True, the number of groups
            is derived per batch from ``'num_dn_queries'``; otherwise the
            fixed ``'num_groups'`` is used.

    NOTE(review): the device is hard-coded to CUDA throughout
    ``forward`` (repeated ``.cuda()`` / ``.to('cuda')`` calls), so this
    module cannot run on CPU as written.
    """

    def __init__(self,
                 num_queries,
                 hidden_dim,
                 num_classes,
                 noise_scale=dict(label=0.5, box=0.4),
                 group_cfg=dict(
                     dynamic=True, num_groups=None, num_dn_queries=None)):
        super(DnQueryGenerator, self).__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.label_noise_scale = noise_scale['label']
        self.box_noise_scale = noise_scale['box']
        self.dynamic_dn_groups = group_cfg.get('dynamic', False)
        if self.dynamic_dn_groups:
            # Dynamic mode: `num_dn` is the total dn-query budget; the
            # actual group count is computed per batch in
            # `get_num_groups`.
            assert 'num_dn_queries' in group_cfg, \
                'num_dn_queries should be set when using ' \
                'dynamic dn groups'
            self.num_dn = group_cfg['num_dn_queries']
        else:
            # Static mode: `num_dn` is the fixed number of groups.
            assert 'num_groups' in group_cfg, \
                'num_groups should be set when using ' \
                'static dn groups'
            self.num_dn = group_cfg['num_groups']
        assert isinstance(self.num_dn, int) and self.num_dn >= 1, \
            f'Expected the num in group_cfg to have type int. ' \
            f'Found {type(self.num_dn)} '

    def get_num_groups(self, group_queries=None):
        """Compute the number of denoising groups for this batch.

        Args:
            group_queries (int, optional): Number of dn queries in one
                group, i.e. the maximum GT count over the batch.
                Required (and only used) in dynamic mode.

        Returns:
            int: Number of groups, always at least 1.
        """
        if self.dynamic_dn_groups:
            assert group_queries is not None, \
                'group_queries should be provided when using ' \
                'dynamic dn groups'
            if group_queries == 0:
                num_groups = 1
            else:
                # Fit as many whole groups as the dn-query budget allows.
                num_groups = self.num_dn // group_queries
        else:
            num_groups = self.num_dn
        if num_groups < 1: # avoid num_groups < 1 in query generator
            num_groups = 1
        return int(num_groups)

    def forward(self,
                gt_bboxes,
                gt_labels=None,
                label_enc=None,
                img_metas=None):
        """Generate denoising queries, boxes, attention mask and meta.

        Args:
            gt_bboxes (list[Tensor]): Per-image GT boxes in xyxy image
                coordinates, each of shape (num_gts, 4).
            gt_labels (list[Tensor], optional): Per-image GT labels,
                each of shape (num_gts,).  Currently required during
                training (asserted below).
            label_enc (nn.Module, optional): Embedding module mapping
                class indices to ``hidden_dim``-dim features.  Required
                during training.
            img_metas (list[dict], optional): Image metas; only
                ``'img_shape'`` is read, to normalize the boxes.
                Required during training.

        Returns:
            tuple: ``(input_query_label, input_query_bbox, attn_mask,
            dn_meta)`` where

                - input_query_label (Tensor): noisy label embeddings,
                  shape (batch_size, pad_size, hidden_dim);
                - input_query_bbox (Tensor): noisy cxcywh boxes mapped
                  through ``inverse_sigmoid``, shape
                  (batch_size, pad_size, 4);
                - attn_mask (BoolTensor): self-attention mask of shape
                  (pad_size + num_queries, pad_size + num_queries),
                  True = attention blocked;
                - dn_meta (dict): ``'pad_size'`` and ``'num_dn_group'``
                  for splitting dn outputs from matching outputs later.

            All four entries are ``None`` in eval mode.
        """
        # TODO: temp only support for CDN
        # TODO: temp assert gt_labels is not None and label_enc is not None

        if self.training:
            if gt_labels is not None:
                assert len(gt_bboxes) == len(gt_labels), \
                    f'the length of provided gt_labels ' \
                    f'{len(gt_labels)} should be equal to' \
                    f' that of gt_bboxes {len(gt_bboxes)}'
            assert gt_labels is not None \
                   and label_enc is not None \
                   and img_metas is not None  # TODO: adjust args
            batch_size = len(gt_bboxes)

            # Convert boxes: xyxy in pixels -> normalized cxcywh in
            # [0, 1], using each image's own shape.
            gt_bboxes_list = []
            for img_meta, bboxes in zip(img_metas, gt_bboxes):
                img_h, img_w, _ = img_meta['img_shape']
                factor = bboxes.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0)
                bboxes_normalized = bbox_xyxy_to_cxcywh(bboxes) / factor
                gt_bboxes_list.append(bboxes_normalized)
            gt_bboxes = gt_bboxes_list

            # One mask entry per GT; known_num[i] is the GT count of
            # image i.
            known = [torch.ones_like(labels) for labels in gt_labels]
            known_num = [sum(k) for k in known]

            num_groups = self.get_num_groups(int(max(known_num)))

            unmask_bbox = unmask_label = torch.cat(known)
            labels = torch.cat(gt_labels)
            boxes = torch.cat(gt_bboxes)
            # Which image each concatenated GT belongs to.
            batch_idx = torch.cat([
                torch.full_like(t.long(), i) for i, t in enumerate(gt_labels)
            ])

            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)

            # Duplicate every GT 2 * num_groups times: each group gets a
            # positive and a negative copy.
            known_indice = known_indice.repeat(2 * num_groups, 1).view(-1)
            known_labels = labels.repeat(2 * num_groups, 1).view(-1)
            known_bid = batch_idx.repeat(2 * num_groups, 1).view(-1)
            known_bboxs = boxes.repeat(2 * num_groups, 1)
            known_labels_expand = known_labels.clone()
            known_bbox_expand = known_bboxs.clone()

            # Label noise: flip a random subset of the duplicated labels
            # to uniformly random classes.
            if self.label_noise_scale > 0:
                p = torch.rand_like(known_labels_expand.float())
                chosen_indice = torch.nonzero(
                    p < (self.label_noise_scale * 0.5)).view(-1)
                new_label = torch.randint_like(chosen_indice, 0,
                                               self.num_classes)
                known_labels_expand.scatter_(0, chosen_indice, new_label)
            single_pad = int(max(known_num))  # TODO

            pad_size = int(single_pad * 2 * num_groups)
            # Within each group's 2 * len(boxes) slots, the first half
            # are positive copies and the second half negatives.
            positive_idx = torch.tensor(range(
                len(boxes))).long().cuda().unsqueeze(0).repeat(num_groups, 1)
            positive_idx += (torch.tensor(range(num_groups)) * len(boxes) *
                             2).long().cuda().unsqueeze(1)
            positive_idx = positive_idx.flatten()
            negative_idx = positive_idx + len(boxes)
            # Box noise: jitter the corners; negatives receive a larger
            # (1x-2x) perturbation to act as hard negatives.
            if self.box_noise_scale > 0:
                # cxcywh -> xyxy so each corner can be moved separately.
                known_bbox_ = torch.zeros_like(known_bboxs)
                known_bbox_[:, : 2] = \
                    known_bboxs[:, : 2] - known_bboxs[:, 2:] / 2
                known_bbox_[:, 2:] = \
                    known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

                # Maximum jitter per coordinate: half width / height.
                diff = torch.zeros_like(known_bboxs)
                diff[:, :2] = known_bboxs[:, 2:] / 2
                diff[:, 2:] = known_bboxs[:, 2:] / 2

                # Random sign in {-1, +1} per coordinate.
                rand_sign = torch.randint_like(
                    known_bboxs, low=0, high=2, dtype=torch.float32)
                rand_sign = rand_sign * 2.0 - 1.0
                rand_part = torch.rand_like(known_bboxs)
                rand_part[negative_idx] += 1.0
                rand_part *= rand_sign
                known_bbox_ += \
                    torch.mul(rand_part, diff).cuda() * self.box_noise_scale
                known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
                # xyxy -> cxcywh back into the expanded buffer.
                known_bbox_expand[:, :2] = \
                    (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
                known_bbox_expand[:, 2:] = \
                    known_bbox_[:, 2:] - known_bbox_[:, :2]

            # Embed noisy labels; boxes go to inverse-sigmoid space so
            # the decoder can use them as unactivated reference points.
            m = known_labels_expand.long().to('cuda')
            input_label_embed = label_enc(m)
            input_bbox_embed = inverse_sigmoid(known_bbox_expand, eps=1e-3)

            # Scatter embeddings into fixed-size padded tensors so every
            # image carries the same number (pad_size) of dn queries.
            padding_label = torch.zeros(pad_size, self.hidden_dim).cuda()
            padding_bbox = torch.zeros(pad_size, 4).cuda()

            input_query_label = padding_label.repeat(batch_size, 1, 1)
            input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

            # map_known_indice[j] is the destination slot (along the
            # pad dimension) of the j-th duplicated GT in its image.
            map_known_indice = torch.tensor([]).to('cuda')
            if len(known_num):
                map_known_indice = torch.cat(
                    [torch.tensor(range(num)) for num in known_num])
                map_known_indice = torch.cat([
                    map_known_indice + single_pad * i
                    for i in range(2 * num_groups)
                ]).long()
            if len(known_bid):
                input_query_label[(known_bid.long(),
                                   map_known_indice)] = input_label_embed
                input_query_bbox[(known_bid.long(),
                                  map_known_indice)] = input_bbox_embed

            tgt_size = pad_size + self.num_queries
            # Start all-False (`x < 0` on ones is never True), then mask
            # the forbidden regions below.
            attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
            # match query cannot see the reconstruct
            attn_mask[pad_size:, :pad_size] = True
            # reconstruct cannot see each other
            for i in range(num_groups):
                if i == 0:
                    # First group: mask all later groups.
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              single_pad * 2 * (i + 1):pad_size] = True
                if i == num_groups - 1:
                    # Last group: mask all earlier groups.
                    attn_mask[single_pad * 2 * i:single_pad * 2 *
                              (i + 1), :single_pad * i * 2] = True
                else:
                    # Middle groups: mask both later and earlier groups.
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              single_pad * 2 * (i + 1):pad_size] = True
                    attn_mask[single_pad * 2 * i:single_pad * 2 *
                              (i + 1), :single_pad * 2 * i] = True

            # Consumed downstream to split dn outputs from matching
            # outputs and to compute dn losses.
            dn_meta = {
                'pad_size': pad_size,
                'num_dn_group': num_groups,
            }
        else:
            # Inference: no denoising queries are added.
            input_query_label = None
            input_query_bbox = None
            attn_mask = None
            dn_meta = None
        return input_query_label, input_query_bbox, attn_mask, dn_meta


class CdnQueryGenerator(DnQueryGenerator):
    """Contrastive denoising (CDN) query generator.

    Behaviorally identical to :class:`DnQueryGenerator` at present (the
    base class already builds positive/negative contrastive pairs); kept
    as a distinct type so configs can select it by name.
    """

    def __init__(self, *args, **kwargs):
        # No extra state: simply delegate to the base class.
        super().__init__(*args, **kwargs)


def build_dn_generator(dn_args):
    """Build a denoising query generator from a config dict.

    Args:
        dn_args (dict | None): Config with a ``'type'`` key naming the
            generator class; all remaining items are forwarded to the
            constructor as keyword arguments.  ``None`` disables
            denoising.

    Returns:
        DnQueryGenerator | None: The constructed generator, or ``None``
        when ``dn_args`` is ``None``.

    Raises:
        NotImplementedError: If ``'type'`` names an unsupported
            generator.
    """
    if dn_args is None:
        return None
    # Work on a shallow copy so the caller's config dict is not mutated
    # by the pop below (the original popped in place, which broke
    # building twice from the same config).
    dn_args = dn_args.copy()
    # Renamed from `type` to avoid shadowing the builtin.
    generator_type = dn_args.pop('type')
    if generator_type == 'DnQueryGenerator':
        return DnQueryGenerator(**dn_args)
    elif generator_type == 'CdnQueryGenerator':
        return CdnQueryGenerator(**dn_args)
    else:
        raise NotImplementedError(f'{generator_type} is not supported yet')