import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops


class PostProcess(nn.Module):
    """Converts the model's raw output into the format expected by the COCO API.

    Args:
        select_box_nums_for_evaluation: number of top-scoring boxes kept per image.
        nms_iou_threshold: IoU threshold for per-image NMS; disabled when <= 0.
        confidence_score: minimum score a prediction must reach; disabled when <= 0.
    """

    def __init__(
        self,
        select_box_nums_for_evaluation=100,
        nms_iou_threshold=-1,
        confidence_score=-1,
    ):
        super().__init__()
        self.select_box_nums_for_evaluation = select_box_nums_for_evaluation
        self.nms_iou_threshold = nms_iou_threshold
        self.confidence_score = confidence_score

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
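        # outputs["pred_logits"]: (batch, num_queries, num_classes) class logits
        # outputs["pred_boxes"]: (batch, num_queries, 4) normalized cxcywh boxes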
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        # take the top-k scores over the flattened (num_queries * num_classes)
        # axis; each flat index encodes both a query index and a class label
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1),
            self.select_box_nums_for_evaluation,
            dim=1,
        )
        scores = topk_values
        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="trunc")
        labels = topk_indexes % out_logits.shape[2]
        # convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2)
        boxes = box_ops.box_convert(out_bbox, in_fmt="cxcywh", out_fmt="xyxy")
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # scale boxes from relative [0, 1] to absolute pixel coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        item_indice = None
        # filter out low-confidence predictions
        if self.confidence_score > 0:
            item_indice = [score > self.confidence_score for score in scores]

        # suppress overlapping predictions with per-image NMS
        if self.nms_iou_threshold > 0:
            # without a confidence filter, start from an all-True keep mask so
            # that NMS can still be applied on its own
            if item_indice is None:
                item_indice = [torch.ones_like(score, dtype=torch.bool) for score in scores]
            nms_indice = [
                box_ops.nms(box, score, iou_threshold=self.nms_iou_threshold)
                for box, score in zip(boxes, scores)
            ]
            nms_binary_indice = [
                torch.zeros_like(item_index, dtype=torch.bool) for item_index in item_indice
            ]
            for nms_binary_index, nms_index in zip(nms_binary_indice, nms_indice):
                nms_binary_index[nms_index] = True
            item_indice = [
                item_index & nms_binary_index
                for item_index, nms_binary_index in zip(item_indice, nms_binary_indice)
            ]

        if item_indice is not None:
            scores = [score[item_index] for score, item_index in zip(scores, item_indice)]
            boxes = [box[item_index] for box, item_index in zip(boxes, item_indice)]
            labels = [label[item_index] for label, item_index in zip(labels, item_indice)]

        if torchvision._is_tracing() and isinstance(scores, torch.Tensor):
            # avoid an iteration warning during ONNX export; when filtering is
            # active, scores/labels/boxes are already per-image lists
            scores, labels, boxes = map(lambda x: x.unbind(0), (scores, labels, boxes))
        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]

        return results
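

# --- Hedged usage sketch (added for illustration; not part of the original
# file). `_demo_post_process` is a hypothetical helper, and the shapes below
# (2 images, 300 queries, 91 classes) are assumptions, not values required by
# any particular model.
def _demo_post_process():
    post_process = PostProcess(
        select_box_nums_for_evaluation=100,
        nms_iou_threshold=0.7,
        confidence_score=0.05,
    )
    outputs = {
        "pred_logits": torch.randn(2, 300, 91),  # per-query class logits
        "pred_boxes": torch.rand(2, 300, 4),     # normalized cxcywh boxes
    }
    target_sizes = torch.tensor([[480, 640], [600, 800]])  # (height, width)
    results = post_process(outputs, target_sizes)
    for result in results:
        print(result["scores"].shape, result["labels"].shape, result["boxes"].shape)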


class SegmentationPostProcess(nn.Module):
    """Converts mask predictions into per-pixel class maps at the original image sizes."""

    @torch.no_grad()
    def forward(self, outputs, target_sizes, input_sizes, batched_input_size):
        out_logits, out_bbox, out_mask = (
            outputs["pred_logits"],
            outputs["pred_boxes"],
            outputs["pred_masks"],
        )

        assert len(out_logits) == len(target_sizes)
        assert len(batched_input_size) == 2

        # average the mask probabilities (after sigmoid) of all queries assigned
        # to the same class, yielding one score map per class
        out_class = out_logits.argmax(-1)
        num_class = out_logits.shape[-1]
        result_masks = []
        for image_id in range(len(out_logits)):
            result_masks_per_image = []
            for cur_class in range(num_class):
                class_index = out_class[image_id] == cur_class
                mask_per_class = out_mask[image_id][class_index].sigmoid()
                if mask_per_class.numel() == 0:
                    # no query predicted this class: fall back to an all-zero map
                    mask_per_class = mask_per_class.new_zeros((1, *mask_per_class.shape[-2:]))
                mask_per_class = mask_per_class.mean(0)
                result_masks_per_image.append(mask_per_class)
            result_masks_per_image = torch.stack(result_masks_per_image, 0)
            result_masks.append(result_masks_per_image)
        result_masks = torch.stack(result_masks, 0)
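        # result_masks: (batch, num_class, H, W) averaged per-class probabilities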

        # upsample masks from 1/4 resolution to the padded batch input size
        result_masks = F.interpolate(
            result_masks,
            size=batched_input_size,
            mode="bilinear",
            align_corners=False,
        )

        # crop away the padding, resize masks to the original image sizes, and
        # collapse the per-class maps into a single class-index map
        mask_results = []
        for mask, (height, width), (out_height, out_width) in zip(
            result_masks,
            input_sizes,
            target_sizes,
        ):
            mask = F.interpolate(
                mask[None, :, :height, :width],
                size=(int(out_height), int(out_width)),
                mode="bilinear",
                align_corners=False,
            )[0]
            mask_results.append({"masks": mask.argmax(0)})

        return mask_results
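

# --- Hedged usage sketch (added for illustration; not part of the original
# file). `_demo_segmentation_post_process` is a hypothetical helper; the sizes
# below assume masks predicted at 1/4 of a 256x256 padded batch, a valid
# (unpadded) region of 224x192, and an original image size of 448x384.
def _demo_segmentation_post_process():
    seg_post_process = SegmentationPostProcess()
    outputs = {
        "pred_logits": torch.randn(1, 100, 21),     # per-query class logits
        "pred_boxes": torch.rand(1, 100, 4),        # read but unused by this module
        "pred_masks": torch.randn(1, 100, 64, 64),  # 1/4-resolution mask logits
    }
    results = seg_post_process(
        outputs,
        target_sizes=[(448, 384)],
        input_sizes=[(224, 192)],
        batched_input_size=(256, 256),
    )
    print(results[0]["masks"].shape)  # expected: torch.Size([448, 384])


if __name__ == "__main__":
    _demo_post_process()
    _demo_segmentation_post_process()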