# coding=utf-8
# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DETR model."""


import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_timm_available,
    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from ..auto import AutoBackbone
from .configuration_detr import DetrConfig


if is_scipy_available():
    from scipy.optimize import linear_sum_assignment

if is_timm_available():
    from timm import create_model

if is_vision_available():
    from transformers.image_transforms import center_to_corners_format

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DetrConfig"
_CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"

DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/detr-resnet-50",
    # See all DETR models at https://huggingface.co/models?filter=detr
]


@dataclass
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    """
    Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
    """

    intermediate_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class DetrModelOutput(Seq2SeqModelOutput):
    """
    Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
    """

    intermediate_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class DetrObjectDetectionOutput(ModelOutput):
    """
    Output type of [`DetrForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[Dict] = None
    logits: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    auxiliary_outputs: Optional[List[Dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class DetrSegmentationOutput(ModelOutput):
    """
    Output type of [`DetrForSegmentation`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
            Segmentation masks logits for all queries. See also
            [`~DetrImageProcessor.post_process_semantic_segmentation`],
            [`~DetrImageProcessor.post_process_instance_segmentation`] or
            [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
            segmentation masks respectively.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[Dict] = None
    logits: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    pred_masks: torch.FloatTensor = None
    auxiliary_outputs: Optional[List[Dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


# BELOW: utilities copied from
# https://github.com/facebookresearch/detr/blob/master/backbone.py
class DetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it user-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        epsilon = 1e-5
        scale = weight * (running_var + epsilon).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias
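
# Illustrative sketch (not part of the model): with identical statistics, the frozen
# variant matches an eval-mode nn.BatchNorm2d, since both apply
# y = (x - running_mean) / sqrt(running_var + 1e-5) * weight + bias:
#
#     reference = nn.BatchNorm2d(8).eval()
#     frozen = DetrFrozenBatchNorm2d(8)
#     frozen.load_state_dict(reference.state_dict(), strict=False)
#     x = torch.randn(2, 8, 4, 4)
#     torch.testing.assert_close(frozen(x), reference(x), rtol=1e-5, atol=1e-5)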


def replace_batch_norm(m, name=""):
    for attr_str in dir(m):
        target_attr = getattr(m, attr_str)
        if isinstance(target_attr, nn.BatchNorm2d):
            frozen = DetrFrozenBatchNorm2d(target_attr.num_features)
            bn = getattr(m, attr_str)
            frozen.weight.data.copy_(bn.weight)
            frozen.bias.data.copy_(bn.bias)
            frozen.running_mean.data.copy_(bn.running_mean)
            frozen.running_var.data.copy_(bn.running_var)
            setattr(m, attr_str, frozen)
    for n, ch in m.named_children():
        replace_batch_norm(ch, n)
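
# Illustrative sketch (assumes only torch and the classes above): any `nn.BatchNorm2d`
# held as a named attribute is swapped for a `DetrFrozenBatchNorm2d` with copied statistics:
#
#     class TinyBlock(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv = nn.Conv2d(3, 8, kernel_size=3)
#             self.bn = nn.BatchNorm2d(8)
#
#     block = TinyBlock()
#     replace_batch_norm(block)
#     assert isinstance(block.bn, DetrFrozenBatchNorm2d)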


class DetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.

    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        if config.use_timm_backbone:
            requires_backends(self, ["timm"])
            kwargs = {}
            if config.dilation:
                kwargs["output_stride"] = 16
            backbone = create_model(
                config.backbone,
                pretrained=config.use_pretrained_backbone,
                features_only=True,
                out_indices=(1, 2, 3, 4),
                in_chans=config.num_channels,
                **kwargs,
            )
        else:
            backbone = AutoBackbone.from_config(config.backbone_config)

        # replace batch norm by frozen batch norm
        with torch.no_grad():
            replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = (
            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
        )

        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
        if "resnet" in backbone_model_type:
            for name, parameter in self.model.named_parameters():
                if config.use_timm_backbone:
                    if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                        parameter.requires_grad_(False)
                else:
                    if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
                        parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps

        out = []
        for feature_map in features:
            # downsample pixel_mask to match shape of corresponding feature_map
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
            out.append((feature_map, mask))
        return out
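
# Illustrative sketch (assumes the default `DetrConfig`, i.e. a timm ResNet-50 backbone;
# pretrained weights are downloaded unless `use_pretrained_backbone` is set to False):
#
#     config = DetrConfig(use_pretrained_backbone=False)
#     conv_encoder = DetrConvEncoder(config)
#     pixel_values = torch.randn(1, 3, 800, 1066)
#     pixel_mask = torch.ones(1, 800, 1066, dtype=torch.long)
#     features = conv_encoder(pixel_values, pixel_mask)
#     # `features` is a list of (feature_map, mask) tuples, one per backbone stage;
#     # for ResNet-50 the last feature map has 2048 channels at 1/32 resolution.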


class DetrConvModel(nn.Module):
    """
    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
    """

    def __init__(self, conv_encoder, position_embedding):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.position_embedding = position_embedding

    def forward(self, pixel_values, pixel_mask):
        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
        out = self.conv_encoder(pixel_values, pixel_mask)
        pos = []
        for feature_map, mask in out:
            # position encoding
            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))

        return out, pos
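
# Illustrative sketch: `DetrConvModel` simply pairs each backbone feature map with its
# position embedding (reusing `config`, `conv_encoder`, `pixel_values` and `pixel_mask`
# from the sketch above):
#
#     conv_model = DetrConvModel(conv_encoder, build_position_encoding(config))
#     features, position_embeddings_list = conv_model(pixel_values, pixel_mask)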


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
    """
    Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
    """
    batch_size, source_len = mask.size()
    target_len = target_len if target_len is not None else source_len

    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
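
# Illustrative sketch: a padding mask of shape (batch_size, source_len) with 1 = keep and
# 0 = pad becomes an additive bias of shape (batch_size, 1, target_len, source_len) that is
# 0 for kept positions and the dtype minimum for padded ones:
#
#     mask = torch.tensor([[1, 1, 0]])
#     bias = _expand_mask(mask, torch.float32, target_len=2)  # shape (1, 1, 2, 3)
#     # bias[..., :2] is 0.0 everywhere, bias[..., 2] is torch.finfo(torch.float32).min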


class DetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, pixel_values, pixel_mask):
        if pixel_mask is None:
            raise ValueError("No pixel mask provided")
        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
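
# Illustrative sketch: for a feature map of height H and width W the sine embedding returns
# a tensor of shape (batch_size, 2 * embedding_dim, H, W), the y and x halves concatenated
# along the channel axis:
#
#     position_embedding = DetrSinePositionEmbedding(embedding_dim=64, normalize=True)
#     feature_map = torch.randn(2, 256, 25, 38)
#     mask = torch.ones(2, 25, 38, dtype=torch.long)
#     pos = position_embedding(feature_map, mask)  # shape (2, 128, 25, 38)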


class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    def forward(self, pixel_values, pixel_mask=None):
        height, width = pixel_values.shape[-2:]
        width_values = torch.arange(width, device=pixel_values.device)
        height_values = torch.arange(height, device=pixel_values.device)
        x_emb = self.column_embeddings(width_values)
        y_emb = self.row_embeddings(height_values)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        return pos
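
# Illustrative sketch: the learned variant supports feature maps up to 50 x 50 and likewise
# returns a tensor of shape (batch_size, 2 * embedding_dim, height, width):
#
#     position_embedding = DetrLearnedPositionEmbedding(embedding_dim=128)
#     pos = position_embedding(torch.randn(2, 256, 25, 38))  # shape (2, 256, 25, 38)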


def build_position_encoding(config):
    n_steps = config.d_model // 2
    if config.position_embedding_type == "sine":
        # TODO find a better way of exposing other arguments
        position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
    elif config.position_embedding_type == "learned":
        position_embedding = DetrLearnedPositionEmbedding(n_steps)
    else:
        raise ValueError(f"Not supported {config.position_embedding_type}")

    return position_embedding


class DetrAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        key_value_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size, target_len, embed_dim = hidden_states.size()

        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # add key-value position embeddings to the key value states
        if key_value_position_embeddings is not None:
            key_value_states_original = key_value_states
            key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
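
# Illustrative sketch: self-attention over a flattened feature map; note that the position
# embeddings are added to the queries and keys only, not to the values:
#
#     attention = DetrAttention(embed_dim=256, num_heads=8)
#     hidden_states = torch.randn(2, 950, 256)  # (batch_size, seq_len, d_model)
#     position_embeddings = torch.randn(2, 950, 256)
#     attn_output, attn_weights = attention(
#         hidden_states, position_embeddings=position_embeddings, output_attentions=True
#     )
#     # attn_output has shape (2, 950, 256), attn_weights has shape (2, 8, 950, 950)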


class DetrEncoderLayer(nn.Module):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor = None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)

        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
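
# Illustrative sketch (assumes the default `DetrConfig`, d_model=256): one encoder layer
# maps the flattened features to features of the same shape:
#
#     config = DetrConfig()
#     encoder_layer = DetrEncoderLayer(config)
#     hidden_states = torch.randn(1, 950, config.d_model)
#     position_embeddings = torch.randn(1, 950, config.d_model)
#     (hidden_states,) = encoder_layer(
#         hidden_states, attention_mask=None, position_embeddings=position_embeddings
#     )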


class DetrDecoderLayer(nn.Module):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = DetrAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            position_embeddings (`torch.FloatTensor`, *optional*):
                position embeddings that are added to the queries and keys in the cross-attention layer.
            query_position_embeddings (`torch.FloatTensor`, *optional*):
                position embeddings that are added to the queries and keys in the self-attention layer.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=query_position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                position_embeddings=query_position_embeddings,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                key_value_position_embeddings=position_embeddings,
                output_attentions=output_attentions,
            )

            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
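
# Illustrative sketch (assumes the default `DetrConfig`, num_queries=100, d_model=256): one
# decoder layer refines the object queries against the encoder output:
#
#     config = DetrConfig()
#     decoder_layer = DetrDecoderLayer(config)
#     queries = torch.randn(1, config.num_queries, config.d_model)
#     query_position_embeddings = torch.randn(1, config.num_queries, config.d_model)
#     encoder_hidden_states = torch.randn(1, 950, config.d_model)
#     encoder_position_embeddings = torch.randn(1, 950, config.d_model)
#     (queries,) = decoder_layer(
#         queries,
#         position_embeddings=encoder_position_embeddings,
#         query_position_embeddings=query_position_embeddings,
#         encoder_hidden_states=encoder_hidden_states,
#     )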


class DetrClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class DetrPreTrainedModel(PreTrainedModel):
    config_class = DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        if isinstance(module, DetrMHAttentionMap):
            nn.init.zeros_(module.k_linear.bias)
            nn.init.zeros_(module.q_linear.bias)
            nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
            nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
        elif isinstance(module, DetrLearnedPositionEmbedding):
            nn.init.uniform_(module.row_embeddings.weight)
            nn.init.uniform_(module.column_embeddings.weight)
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, DetrDecoder):
            module.gradient_checkpointing = value


DETR_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DetrConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DETR_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it.

            Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).

            [What are attention masks?](../glossary#attention-mask)

        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`DetrEncoderLayer`].

    The encoder updates the flattened feature map through multiple self-attention layers.

    Small tweak for DETR:

    - position_embeddings are added to the forward pass.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        position_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)

            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Position embeddings that are added to the queries and keys in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
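            # _expand_mask also inverts the mask: real positions become 0.0 and padded positions become a large
            # negative bias that is added to the attention scores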
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                # we add position_embeddings as extra input to the encoder_layer
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    position_embeddings=position_embeddings,
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DETR:

    - position_embeddings and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop

        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings=None,
        query_position_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            input_shape = inputs_embeds.size()[:-1]

        combined_attention_mask = None

        if attention_mask is not None and combined_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            combined_attention_mask = combined_attention_mask + _expand_mask(
                attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            encoder_attention_mask = _expand_mask(
                encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
            )

        # optional intermediate hidden states
        intermediate = () if self.config.auxiliary_loss else None
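        # when auxiliary_loss is enabled, the (layernormed) output of every decoder layer is kept so that the
        # prediction heads can also be applied to these intermediate activations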

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    combined_attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=combined_attention_mask,
                    position_embeddings=position_embeddings,
                    query_position_embeddings=query_position_embeddings,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if self.config.auxiliary_loss:
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # stack intermediate decoder activations
        if self.config.auxiliary_loss:
            intermediate = torch.stack(intermediate)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
                if v is not None
            )
        return DetrDecoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
            intermediate_hidden_states=intermediate,
        )


@add_start_docstrings(
    """
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """,
    DETR_START_DOCSTRING,
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Create backbone + positional encoding
        backbone = DetrConvEncoder(config)
        position_embeddings = build_position_encoding(config)
        self.backbone = DetrConvModel(backbone, position_embeddings)

        # Create projection layer
        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def freeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(True)

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values,
        pixel_mask=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, send pixel_values + pixel_mask through the backbone to obtain the features
        # pixel_values should be of shape (batch_size, num_channels, height, width)
        # pixel_mask should be of shape (batch_size, height, width)
        features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)

        # get final feature map and downsampled mask
        feature_map, mask = features[-1]
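        # with the default ResNet-50 backbone, feature_map has shape (batch_size, 2048, height/32, width/32)
        # and mask is the pixel_mask downsampled to the same spatial resolution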

        if mask is None:
            raise ValueError("Backbone does not return downsampled pixel mask")

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        projected_feature_map = self.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                position_embeddings=position_embeddings,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        queries = torch.zeros_like(query_position_embeddings)
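        # the content queries start as zeros; the learned object queries only enter via query_position_embeddings,
        # which are added to the attention inputs at every decoder layer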

        # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            position_embeddings=position_embeddings,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )


@add_start_docstrings(
    """
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """,
    DETR_START_DOCSTRING,
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )  # We add one for the "no object" class
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values,
        pixel_mask=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to COCO API
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # First, send images through the DETR base model to obtain encoder + decoder outputs
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
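        # the sigmoid maps each predicted box to normalized (center_x, center_y, width, height) coordinates in [0, 1]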

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            # First: create the matcher
            matcher = DetrHungarianMatcher(
                class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
            )
            # Second: create the criterion
            losses = ["labels", "boxes", "cardinality"]
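            # "labels" -> cross-entropy classification loss, "boxes" -> L1 + generalized IoU box losses,
            # "cardinality" -> absolute error in the predicted number of objects (logged only, no gradient)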
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=self.config.num_labels,
                eos_coef=self.config.eos_coefficient,
                losses=losses,
            )
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = logits
            outputs_loss["pred_boxes"] = pred_boxes
            if self.config.auxiliary_loss:
                intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
                auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
                outputs_loss["auxiliary_outputs"] = auxiliary_outputs

            loss_dict = criterion(outputs_loss, labels)
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = self.config.giou_loss_coefficient
            if self.config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(self.config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + auxiliary_outputs + outputs
            else:
                output = (logits, pred_boxes) + outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


@add_start_docstrings(
    """
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.

    """,
    DETR_START_DOCSTRING,
)
class DetrForSegmentation(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes

        self.mask_head = DetrMaskHeadSmallConv(
            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
        )

        self.bbox_attention = DetrMHAttentionMap(
            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DetrSegmentationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values,
        pixel_mask=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
            bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
            should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.

        Returns:

        Examples:

        ```python
        >>> import io
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

        >>> from transformers import AutoImageProcessor, DetrForSegmentation
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
        >>> # Segmentation results are returned as a list of dictionaries
        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
        >>> panoptic_seg = result[0]["segmentation"]
        >>> # Get prediction score and segment_id to class_id mapping of each segment
        >>> panoptic_segments_info = result[0]["segments_info"]
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, get list of feature maps and position embeddings
        features, position_embeddings_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        feature_map, mask = features[-1]
        batch_size, num_channels, height, width = feature_map.shape
        projected_feature_map = self.detr.model.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                position_embeddings=position_embeddings,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        queries = torch.zeros_like(query_position_embeddings)

        # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            position_embeddings=position_embeddings,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Sixth, compute logits, pred_boxes and pred_masks
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

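        # reshape the encoder output of shape (batch_size, sequence_length, d_model) back into a 2D feature map;
        # height and width here are the spatial dimensions of the smallest backbone feature map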
        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
        mask = flattened_mask.view(batch_size, height, width)

        # FIXME h_boxes takes the last one computed, keep this in mind
        # important: we need to reverse the mask, since in the original implementation the mask works reversed
        # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)

        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])

        pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            # First: create the matcher
            matcher = DetrHungarianMatcher(
                class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
            )
            # Second: create the criterion
            losses = ["labels", "boxes", "cardinality", "masks"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=self.config.num_labels,
                eos_coef=self.config.eos_coefficient,
                losses=losses,
            )
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = logits
            outputs_loss["pred_boxes"] = pred_boxes
            outputs_loss["pred_masks"] = pred_masks
            if self.config.auxiliary_loss:
                intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
                auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
                outputs_loss["auxiliary_outputs"] = auxiliary_outputs

            loss_dict = criterion(outputs_loss, labels)
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = self.config.giou_loss_coefficient
            weight_dict["loss_mask"] = self.config.mask_loss_coefficient
            weight_dict["loss_dice"] = self.config.dice_loss_coefficient
            if self.config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(self.config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
            else:
                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
class DetrMaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm. Upsampling is done using a FPN approach
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        if dim % 8 != 0:
            raise ValueError(
                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
                " GroupNorm is set to 8"
            )

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]

        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)
        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = nn.GroupNorm(8, inter_dims[1])
        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = nn.GroupNorm(8, inter_dims[2])
        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = nn.GroupNorm(8, inter_dims[3])
        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = nn.GroupNorm(8, inter_dims[4])
        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
        # We expand the projected feature map to match the number of heads.
        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
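        # x now has shape (batch_size * num_queries, d_model + num_heads, height, width), where height and width
        # are those of the smallest backbone feature map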

        x = self.lay1(x)
        x = self.gn1(x)
        x = nn.functional.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = nn.functional.relu(x)

        x = self.out_lay(x)
        return x


class DetrMHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        q = self.q_linear(q)
        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
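        # weights has shape (batch_size, num_queries, num_heads, height, width): one spatial attention map per query and head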

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)
        return weights


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs (0 for the negative class and 1 for the positive
                 class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
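    # the +1 in numerator and denominator smooths the loss and keeps it well-defined when masks are empty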
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (`torch.FloatTensor` of arbitrary shape):
            The predictions for each example.
        targets (`torch.FloatTensor` with the same shape as `inputs`):
            A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
            and 1 for the positive class).
        alpha (`float`, *optional*, defaults to `0.25`):
            Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
        gamma (`float`, *optional*, defaults to `2`):
            Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # add modulating factor
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
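    # p_t is the probability the model assigns to the ground-truth class; (1 - p_t) ** gamma down-weights easy examples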

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes


# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DetrLoss(nn.Module):
    """
    This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
    we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
    of matched ground-truth / prediction (supervise class and box).

    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
    (`max_obj_id` + 1). For more details on this, check the following discussion
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"


    Args:
        matcher (`DetrHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        num_classes (`int`):
            Number of object categories, omitting the special no-object category.
        eos_coef (`float`):
            Relative classification weight applied to the no-object category.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
    """

    def __init__(self, matcher, num_classes, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    # removed logging parameter, which was part of the original implementation
    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL). Targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes].
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}

        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss; it is intended for logging purposes only. It doesn't propagate gradients.
        """
        logits = outputs["logits"]
        device = logits.device
        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.

        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
        """
        if "pred_masks" not in outputs:
            raise KeyError("No predicted masks found in outputs")

        source_idx = self._get_source_permutation_idx(indices)
        target_idx = self._get_target_permutation_idx(indices)
        source_masks = outputs["pred_masks"]
        source_masks = source_masks[source_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(source_masks)
        target_masks = target_masks[target_idx]

        # upsample predictions to the target size
        source_masks = nn.functional.interpolate(
            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        source_masks = source_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(source_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
        }
        return losses

    def _get_source_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def _get_target_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
        target_idx = torch.cat([target for (_, target) in indices])
        return batch_idx, target_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets):
        """
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
                losses applied, see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        # (Niels): comment out function below, distributed training to be added
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_boxes)
        # (Niels) in original implementation, num_boxes is divided by get_world_size()
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "auxiliary_outputs" in outputs:
            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
                indices = self.matcher(auxiliary_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
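
# Illustrative sketch (hypothetical dimensions): DetrForObjectDetection uses a 3-layer head like this to map
# each decoder hidden state to 4 normalized box coordinates, e.g.
#
#     bbox_predictor = DetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#     pred_boxes = bbox_predictor(torch.randn(2, 100, 256)).sigmoid()  # (batch_size, num_queries, 4)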


# taken from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
class DetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()
        requires_backends(self, ["scipy"])

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
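
# Illustrative sketch of the matcher output format (hypothetical numbers): for a batch with 100 queries and a
# single image containing 3 ground-truth boxes, `indices` could look like
#
#     [(tensor([12, 57, 91]), tensor([2, 0, 1]))]
#
# meaning that query 12 is matched to target 2, query 57 to target 0 and query 91 to target 1.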


# below: bounding box utilities taken from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py


def _upcast(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results
    # so do an early check
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    width_height = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    area = width_height[:, :, 0] * width_height[:, :, 1]

    return iou - (area - union) / area
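
# Illustrative check of the GIoU behaviour: for two disjoint unit boxes the IoU is 0 while the GIoU is
# negative, penalizing the distance between them:
#
#     boxes1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
#     boxes2 = torch.tensor([[2.0, 0.0, 3.0, 1.0]])
#     generalized_box_iou(boxes1, boxes2)  # tensor([[-0.3333]]), i.e. 0 - (3 - 2) / 3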


# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    if tensor_list[0].ndim == 3:
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        batch_size, num_channels, height, width = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("Only 3-dimensional tensors are supported")
    return NestedTensor(tensor, mask)
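
# Illustrative sketch (hypothetical image sizes): two images of different resolutions are zero-padded to a
# common shape, and the boolean mask marks padded pixels with True:
#
#     images = [torch.randn(3, 480, 640), torch.randn(3, 360, 500)]
#     nested = nested_tensor_from_tensor_list(images)
#     nested.tensors.shape  # torch.Size([2, 3, 480, 640])
#     nested.mask.shape     # torch.Size([2, 480, 640])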