utils.py 15.9 KB
Newer Older
dlyrm's avatar
dlyrm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from detrex (https://github.com/IDEA-Research/detrex)
# Copyright 2022 The IDEA Authors. All rights reserved.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..bbox_utils import bbox_overlaps

# Names exported by a star-import of this module. ``bbox_overlaps`` is not
# defined here; it is re-exported from ``..bbox_utils``.
# NOTE(review): several functions defined further down in this file (for
# example ``get_valid_ratio``, ``get_denoising_training_group``,
# ``get_contrastive_denoising_training_group``, ``get_sine_pos_embed`` and
# ``mask_to_box_coordinate``) are not listed -- confirm this is intentional.
__all__ = [
    '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy',
    'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid',
    'deformable_attention_core_func', 'varifocal_loss_with_logits'
]


def _get_clones(module, N):
    """Build an ``nn.LayerList`` of ``N`` independent deep copies of ``module``.

    Args:
        module: The layer to replicate; it is deep-copied, never shared.
        N (int): Number of copies to create.

    Returns:
        nn.LayerList: Container holding the ``N`` copies.
    """
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.LayerList(clones)


def bbox_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2).

    Args:
        x (Tensor): Boxes whose last axis is split into two equal halves:
            the centre coordinates followed by the sizes.

    Returns:
        Tensor: Boxes with the corner pair (centre - size / 2,
        centre + size / 2) concatenated on the last axis.
    """
    center, size = paddle.split(x, 2, axis=-1)
    half_size = 0.5 * size
    top_left = center - half_size
    bottom_right = center + half_size
    return paddle.concat([top_left, bottom_right], axis=-1)


def bbox_xyxy_to_cxcywh(x):
    """Convert boxes from (x1, y1, x2, y2) to (cx, cy, w, h).

    Args:
        x (Tensor): Boxes whose last axis is split into four equal parts
            (x1, y1, x2, y2).

    Returns:
        Tensor: Centre x, centre y, width and height concatenated on the
        last axis.
    """
    left, top, right, bottom = x.split(4, axis=-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return paddle.concat([center_x, center_y, width, height], axis=-1)


def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
    """Compute the sigmoid focal loss from raw logits.

    Args:
        logit (Tensor): Raw (pre-sigmoid) predictions.
        label (Tensor): Targets used both for the cross entropy and for
            weighting; combined element-wise with ``logit``.
        normalizer (float): Value the summed loss is divided by.
        alpha (float): Class-balancing factor; the balancing term is skipped
            when ``alpha`` is negative.
        gamma (float): Exponent of the ``(1 - p_t)`` modulating factor.

    Returns:
        Tensor: The element-wise loss averaged over axis 1, summed, and
        divided by ``normalizer``.
    """
    probability = F.sigmoid(logit)
    cross_entropy = F.binary_cross_entropy_with_logits(
        logit, label, reduction="none")
    # Probability assigned to the target value of each element.
    p_t = probability * label + (1 - probability) * (1 - label)
    modulating_factor = (1 - p_t)**gamma
    focal = cross_entropy * modulating_factor

    if alpha >= 0:
        balance = alpha * label + (1 - alpha) * (1 - label)
        focal = balance * focal

    per_sample = focal.mean(1)
    return per_sample.sum() / normalizer


def inverse_sigmoid(x, eps=1e-5):
    """Compute ``log(x / (1 - x))`` with clamping for numerical safety.

    Args:
        x (Tensor): Input values; clipped to ``[0, 1]`` before use.
        eps (float): Lower bound applied to both the numerator and the
            denominator so the logarithm stays finite.

    Returns:
        Tensor: The clamped logit of ``x``.
    """
    x = x.clip(min=0., max=1.)
    numerator = x.clip(min=eps)
    denominator = (1 - x).clip(min=eps)
    return paddle.log(numerator / denominator)


def deformable_attention_core_func(value, value_spatial_shapes,
                                   value_level_start_index, sampling_locations,
                                   attention_weights):
    """
    Sample multi-level feature maps at the given locations and combine the
    samples with the attention weights (multi-scale deformable attention).

    Args:
        value (Tensor): [bs, value_length, n_head, c]
        value_spatial_shapes (Tensor|List): [n_levels, 2], iterated as (h, w)
            pairs; the h * w products are used to split `value` along axis 1.
        value_level_start_index (Tensor|List): [n_levels]. Not read by this
            implementation; the per-level split is derived from
            `value_spatial_shapes` instead.
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
            (assumed to be normalized to [0, 1] -- TODO confirm with callers)
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, Length_{query}, C], where C = n_head * c
    """
    bs, _, n_head, c = value.shape
    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape

    # Number of flattened positions belonging to each level.
    split_shape = [h * w for h, w in value_spatial_shapes]
    value_list = value.split(split_shape, axis=1)
    # Rescale the locations linearly (x -> 2x - 1) before grid sampling.
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(
            [0, 2, 1]).reshape([bs * n_head, c, h, w])
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(
            [0, 2, 1, 3, 4]).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape(
        [bs * n_head, 1, Len_q, n_levels * n_points])
    # Stack the per-level samples, merge the (level, point) axes, weight them
    # and sum over that merged axis; then fold the heads back into channels.
    output = (paddle.stack(
        sampling_value_list, axis=-2).flatten(-2) *
              attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])

    # [bs, n_head * c, Len_q] -> [bs, Len_q, n_head * c]
    return output.transpose([0, 2, 1])


def get_valid_ratio(mask):
    """Compute the per-sample valid fraction of a mask along width and height.

    The first column of ``mask`` is summed and divided by the height, and the
    first row is summed and divided by the width.

    Args:
        mask (Tensor): 3-D tensor unpacked as (batch, H, W).

    Returns:
        Tensor: ``[b, 2]`` tensor ordered as (width ratio, height ratio).
    """
    _, height, width = paddle.shape(mask)
    first_column = mask[:, :, 0]
    first_row = mask[:, 0, :]
    ratio_h = paddle.sum(first_column, 1) / height
    ratio_w = paddle.sum(first_row, 1) / width
    # [b, 2]
    return paddle.stack([ratio_w, ratio_h], -1)


def get_denoising_training_group(targets,
                                 num_classes,
                                 num_queries,
                                 class_embed,
                                 num_denoising=100,
                                 label_noise_ratio=0.5,
                                 box_noise_scale=1.0):
    """Build noised ground-truth queries and the attention mask for denoising
    training.

    Ground-truth classes and boxes are padded to the largest count in the
    batch, repeated ``num_group`` times, optionally perturbed, and returned
    together with a mask over the denoising and matching queries.

    Args:
        targets (dict): Reads ``targets["gt_class"]`` and
            ``targets["gt_bbox"]``, each indexed per image. Each class entry
            is squeezed on its last axis before use.
        num_classes (int): Number of classes; also used as the padding class
            index.
        num_queries (int): Number of matching queries that follow the
            denoising queries.
        class_embed (Tensor): Class embedding table gathered along axis 0; a
            zero row is appended for the padding index.
        num_denoising (int): Requested denoising query budget. Internally
            replaced by ``max_gt_num * num_group``.
        label_noise_ratio (float): Label noise is applied only when > 0; each
            valid label is then replaced with probability
            ``label_noise_ratio * 0.5``.
        box_noise_scale (float): Box noise (followed by ``inverse_sigmoid``)
            is applied only when > 0.

    Returns:
        tuple: ``(input_query_class, input_query_bbox, attn_mask, dn_meta)``,
        or four ``None`` values when ``num_denoising <= 0`` or the batch has
        no ground truth. ``input_query_class`` holds the gathered embeddings
        reshaped to ``[bs, num_denoising, -1]``; ``attn_mask`` is a
        ``[tgt_size, tgt_size]`` boolean tensor that is inverted before being
        returned, so blocked pairs are ``False``; ``dn_meta`` contains
        ``dn_positive_idx``, ``dn_num_group`` and ``dn_num_split``.
    """
    if num_denoising <= 0:
        return None, None, None, None
    num_gts = [len(t) for t in targets["gt_class"]]
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None

    # At least one group, even when the budget is smaller than max_gt_num.
    num_group = num_denoising // max_gt_num
    num_group = 1 if num_group == 0 else num_group
    # pad gt to max_num of a batch
    bs = len(targets["gt_class"])
    input_query_class = paddle.full(
        [bs, max_gt_num], num_classes, dtype='int32')
    input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
    pad_gt_mask = paddle.zeros([bs, max_gt_num])
    for i in range(bs):
        num_gt = num_gts[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
            input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
            pad_gt_mask[i, :num_gt] = 1

    # Repeat the padded targets once per denoising group.
    input_query_class = input_query_class.tile([1, num_group])
    input_query_bbox = input_query_bbox.tile([1, num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, num_group])

    # Column indices of the real (non-padded) queries, split per image.
    dn_positive_idx = paddle.nonzero(pad_gt_mask)[:, 1]
    dn_positive_idx = paddle.split(dn_positive_idx,
                                   [n * num_group for n in num_gts])
    # total denoising queries
    num_denoising = int(max_gt_num * num_group)

    if label_noise_ratio > 0:
        input_query_class = input_query_class.flatten()
        pad_gt_mask = pad_gt_mask.flatten()
        # half of bbox prob
        mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
        # Only real (non-padded) positions may have their label replaced.
        chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
        # randomly put a new one here
        new_label = paddle.randint_like(
            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class.scatter_(chosen_idx, new_label)
        # Restore the [bs, num_denoising] layout in place.
        input_query_class.reshape_([bs, num_denoising])
        pad_gt_mask.reshape_([bs, num_denoising])

    if box_noise_scale > 0:
        # Per-coordinate noise magnitude: (w / 2, h / 2) for the first two
        # coordinates and (w, h) for the last two, times the scale.
        diff = paddle.concat(
            [input_query_bbox[..., 2:] * 0.5, input_query_bbox[..., 2:]],
            axis=-1) * box_noise_scale
        # Uniform factor in [-1, 1).
        diff *= (paddle.rand(input_query_bbox.shape) * 2.0 - 1.0)
        input_query_bbox += diff
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    # Extra zero row so the padding index ``num_classes`` can be gathered.
    class_embed = paddle.concat(
        [class_embed, paddle.zeros([1, class_embed.shape[-1]])])
    input_query_class = paddle.gather(
        class_embed, input_query_class.flatten(),
        axis=0).reshape([bs, num_denoising, -1])

    tgt_size = num_denoising + num_queries
    # All-False boolean matrix; True is used below to mark blocked pairs.
    attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
    # match query cannot see the reconstruction
    attn_mask[num_denoising:, :num_denoising] = True
    # reconstruct cannot see each other
    # NOTE: the first ``if`` is independent of the ``if``/``else`` that
    # follows, so for i == 0 with num_group > 1 the same slice is assigned
    # twice. This is redundant but does not change the resulting mask.
    for i in range(num_group):
        if i == 0:
            attn_mask[max_gt_num * i:max_gt_num * (i + 1), max_gt_num * (i + 1):
                      num_denoising] = True
        if i == num_group - 1:
            attn_mask[max_gt_num * i:max_gt_num * (i + 1), :max_gt_num *
                      i] = True
        else:
            attn_mask[max_gt_num * i:max_gt_num * (i + 1), max_gt_num * (i + 1):
                      num_denoising] = True
            attn_mask[max_gt_num * i:max_gt_num * (i + 1), :max_gt_num *
                      i] = True
    # Invert so that blocked pairs become False.
    attn_mask = ~attn_mask
    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [num_denoising, num_queries]
    }

    return input_query_class, input_query_bbox, attn_mask, dn_meta


def get_contrastive_denoising_training_group(targets,
                                             num_classes,
                                             num_queries,
                                             class_embed,
                                             num_denoising=100,
                                             label_noise_ratio=0.5,
                                             box_noise_scale=1.0):
    """Build positive/negative noised ground-truth queries and the attention
    mask for contrastive denoising training.

    Ground-truth classes and boxes are padded to the largest count in the
    batch and repeated ``2 * num_group`` times: within each group the first
    ``max_gt_num`` slots are positive queries and the next ``max_gt_num``
    slots are negative queries, which receive larger box noise.

    Args:
        targets (dict): Reads ``targets["gt_class"]`` and
            ``targets["gt_bbox"]``, each indexed per image. Each class entry
            is squeezed on its last axis before use.
        num_classes (int): Number of classes; also used as the padding class
            index.
        num_queries (int): Number of matching queries that follow the
            denoising queries.
        class_embed (Tensor): Class embedding table gathered along axis 0; a
            zero row is appended for the padding index.
        num_denoising (int): Requested denoising query budget. Internally
            replaced by ``max_gt_num * 2 * num_group``.
        label_noise_ratio (float): Label noise is applied only when > 0; each
            valid label is then replaced with probability
            ``label_noise_ratio * 0.5``.
        box_noise_scale (float): Box noise (followed by ``inverse_sigmoid``)
            is applied only when > 0.

    Returns:
        tuple: ``(input_query_class, input_query_bbox, attn_mask, dn_meta)``,
        or four ``None`` values when ``num_denoising <= 0`` or the batch has
        no ground truth. ``input_query_class`` holds the gathered embeddings
        reshaped to ``[bs, num_denoising, -1]``; ``attn_mask`` is a
        ``[tgt_size, tgt_size]`` boolean tensor that is inverted before being
        returned, so blocked pairs are ``False``; ``dn_meta`` contains
        ``dn_positive_idx``, ``dn_num_group`` and ``dn_num_split``.
    """
    if num_denoising <= 0:
        return None, None, None, None
    num_gts = [len(t) for t in targets["gt_class"]]
    max_gt_num = max(num_gts)
    if max_gt_num == 0:
        return None, None, None, None

    # At least one group, even when the budget is smaller than max_gt_num.
    num_group = num_denoising // max_gt_num
    num_group = 1 if num_group == 0 else num_group
    # pad gt to max_num of a batch
    bs = len(targets["gt_class"])
    input_query_class = paddle.full(
        [bs, max_gt_num], num_classes, dtype='int32')
    input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
    pad_gt_mask = paddle.zeros([bs, max_gt_num])
    for i in range(bs):
        num_gt = num_gts[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
            input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_group])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
    # positive and negative mask
    # Within one group of 2 * max_gt_num slots, the second half is negative.
    negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1])
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    # Keep only positive slots that hold a real (non-padded) ground truth.
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]
    dn_positive_idx = paddle.split(dn_positive_idx,
                                   [n * num_group for n in num_gts])
    # total denoising queries
    num_denoising = int(max_gt_num * 2 * num_group)

    if label_noise_ratio > 0:
        input_query_class = input_query_class.flatten()
        pad_gt_mask = pad_gt_mask.flatten()
        # half of bbox prob
        mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
        # Only real (non-padded) positions may have their label replaced.
        chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
        # randomly put a new one here
        new_label = paddle.randint_like(
            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class.scatter_(chosen_idx, new_label)
        # Restore the [bs, num_denoising] layout in place.
        input_query_class.reshape_([bs, num_denoising])
        pad_gt_mask.reshape_([bs, num_denoising])

    if box_noise_scale > 0:
        # Noise is applied to the corner representation of the boxes.
        known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)

        # Per-corner noise magnitude: (w / 2, h / 2, w / 2, h / 2) * scale.
        diff = paddle.tile(input_query_bbox[..., 2:] * 0.5,
                           [1, 1, 2]) * box_noise_scale

        # Random sign in {-1, +1}.
        rand_sign = paddle.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = paddle.rand(input_query_bbox.shape)
        # Negative queries get a factor in [1, 2), positive ones in [0, 1).
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
            1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    # Extra zero row so the padding index ``num_classes`` can be gathered.
    class_embed = paddle.concat(
        [class_embed, paddle.zeros([1, class_embed.shape[-1]])])
    input_query_class = paddle.gather(
        class_embed, input_query_class.flatten(),
        axis=0).reshape([bs, num_denoising, -1])

    tgt_size = num_denoising + num_queries
    # All-False boolean matrix; True is used below to mark blocked pairs.
    attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
    # match query cannot see the reconstruction
    attn_mask[num_denoising:, :num_denoising] = True
    # reconstruct cannot see each other
    # NOTE: the first ``if`` is independent of the ``if``/``else`` that
    # follows, so for i == 0 with num_group > 1 the same slice is assigned
    # twice. This is redundant but does not change the resulting mask.
    for i in range(num_group):
        if i == 0:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
                      2 * (i + 1):num_denoising] = True
        if i == num_group - 1:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
                      i * 2] = True
        else:
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
                      2 * (i + 1):num_denoising] = True
            attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
                      2 * i] = True
    # Invert so that blocked pairs become False.
    attn_mask = ~attn_mask
    dn_meta = {
        "dn_positive_idx": dn_positive_idx,
        "dn_num_group": num_group,
        "dn_num_split": [num_denoising, num_queries]
    }

    return input_query_class, input_query_bbox, attn_mask, dn_meta


def get_sine_pos_embed(pos_tensor,
                       num_pos_feats=128,
                       temperature=10000,
                       exchange_xy=True):
    """Generate a sine/cosine position embedding from a position tensor.

    Each scalar on the last axis of `pos_tensor` is expanded into
    `num_pos_feats` features: sine at even frequency indices interleaved
    with cosine at odd ones.

    Args:
        pos_tensor (Tensor): 3-D tensor whose last axis has size `n`. The
            code indexes three axes and concatenates on axis 2, so the
            layout is presumably `(batch, num_positions, n)` -- confirm
            against callers.
        num_pos_feats (int): Number of features produced for each scalar in
            the tensor. Default: 128.
        temperature (int): The temperature used for scaling
            the position embedding. Default: 10000.
        exchange_xy (bool, optional): Swap the embeddings of the first two
            components. For example, if the input is `[x, y]`, the result
            is `[pos(y), pos(x)]`. Requires `n >= 2`. Default: True.

    Returns:
        Tensor: Position embedding whose last axis has size
        `n * num_pos_feats`.
    """
    scale = 2. * math.pi
    # Pairs of consecutive feature indices share one frequency: 0, 0, 2, 2, ...
    dim_t = 2. * paddle.floor_divide(
        paddle.arange(num_pos_feats), paddle.to_tensor(2))
    # Per-feature multiplier: 2 * pi / temperature ** (dim_t / num_pos_feats).
    dim_t = scale / temperature**(dim_t / num_pos_feats)

    def sine_func(x):
        # x has a trailing axis of size 1 and is broadcast against dim_t.
        x *= dim_t
        # Interleave sin (even indices) and cos (odd indices) on the last axis.
        return paddle.stack(
            (x[:, :, 0::2].sin(), x[:, :, 1::2].cos()), axis=3).flatten(2)

    # One embedding per component of the last axis.
    pos_res = [sine_func(x) for x in pos_tensor.split(pos_tensor.shape[-1], -1)]
    if exchange_xy:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    pos_res = paddle.concat(pos_res, axis=2)
    return pos_res


def mask_to_box_coordinate(mask,
                           normalize=False,
                           format="xyxy",
                           dtype="float32"):
    """
    Compute the bounding boxes around the provided mask.

    Args:
        mask (Tensor:bool): [b, c, h, w]
        normalize (bool): When True, divide the boxes by [w, h, w, h].
        format (str): "xyxy" returns corner boxes. "xywh" returns the result
            of `bbox_xyxy_to_cxcywh` applied to the corner boxes, i.e.
            (cx, cy, w, h).
        dtype (str): Data type of the coordinate grids and of the all-zero
            result returned for an entirely empty mask.

    Returns:
        bbox (Tensor): [b, c, 4]. The max corner is the largest masked index
            plus one.
    """
    assert mask.ndim == 4
    assert format in ["xyxy", "xywh"]
    if mask.sum() == 0:
        return paddle.zeros([mask.shape[0], mask.shape[1], 4], dtype=dtype)

    h, w = mask.shape[-2:]
    row_index = paddle.arange(end=h, dtype=dtype)
    col_index = paddle.arange(end=w, dtype=dtype)
    y, x = paddle.meshgrid(row_index, col_index)

    def _extent(coord):
        # Lower/upper bound of the masked coordinates over each [h, w] map.
        masked = coord * mask
        upper = masked.flatten(-2).max(-1) + 1
        # Unmasked positions get a large sentinel so they never win the min.
        lower = paddle.where(mask, masked,
                             paddle.to_tensor(1e8)).flatten(-2).min(-1)
        return lower, upper

    x_min, x_max = _extent(x)
    y_min, y_max = _extent(y)

    out_bbox = paddle.stack([x_min, y_min, x_max, y_max], axis=-1)
    if normalize:
        out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype)

    if format == "xyxy":
        return out_bbox
    return bbox_xyxy_to_cxcywh(out_bbox)


def varifocal_loss_with_logits(pred_logits,
                               gt_score,
                               label,
                               normalizer=1.0,
                               alpha=0.75,
                               gamma=2.0):
    """Compute the varifocal loss from raw logits.

    Elements with ``label == 1`` are weighted by ``gt_score``; elements with
    ``label == 0`` are weighted by ``alpha * sigmoid(pred_logits) ** gamma``.

    Args:
        pred_logits (Tensor): Raw (pre-sigmoid) predictions.
        gt_score (Tensor): Target scores used as the cross-entropy target.
        label (Tensor): Indicator combined element-wise with the weights.
        normalizer (float): Value the summed loss is divided by.
        alpha (float): Scale of the weight on ``label == 0`` elements.
        gamma (float): Exponent applied to the predicted score.

    Returns:
        Tensor: The weighted loss averaged over axis 1, summed, and divided
        by ``normalizer``.
    """
    pred_score = F.sigmoid(pred_logits)
    negative_weight = alpha * pred_score.pow(gamma) * (1 - label)
    positive_weight = gt_score * label
    weight = negative_weight + positive_weight
    elementwise = F.binary_cross_entropy_with_logits(
        pred_logits, gt_score, weight=weight, reduction='none')
    per_sample = elementwise.mean(1)
    return per_sample.sum() / normalizer