import torch
import torchvision
from torch import nn


class DETRBaseTransformer(nn.Module):
    """Shared utilities for DETR-family transformers (DeformableTransformer,
    DabTransformer, DINOTransformer, AlignTransformer): per-level embeddings,
    multi-level flattening, and per-level bookkeeping (spatial shapes, start
    indices, valid ratios).
    """
    def __init__(self, num_feature_levels, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_feature_levels = num_feature_levels
        # One learnable embedding per feature level, initialized from N(0, 1).
        self.level_embeds = nn.Parameter(torch.empty(num_feature_levels, embed_dim))
        self._init_weights_detr_transformer()

    def _init_weights_detr_transformer(self):
        nn.init.normal_(self.level_embeds)

    @staticmethod
    def flatten_multi_level(multi_level_elements):
        """Flatten each level's spatial dims and concatenate along the sequence axis.

        Elements of shape (b, [c], h, w) become one tensor of shape (b, s) or
        (b, s, c), where s is the total spatial size over all levels.
        """
        flat = torch.cat([level.flatten(-2) for level in multi_level_elements], dim=-1)
        if flat.ndim == 3:
            # (b, c, s) -> (b, s, c)
            flat = flat.transpose(1, 2)
        return flat

    def get_lvl_pos_embed(self, multi_level_pos_embeds):
        """Add the learned per-level embedding to each positional map, then flatten."""
        with_level = [
            pos + emb.view(1, -1, 1, 1)
            for pos, emb in zip(multi_level_pos_embeds, self.level_embeds)
        ]
        return self.flatten_multi_level(with_level)

    def multi_level_misc(self, multi_level_masks):
        """Derive per-level metadata from the padding masks.

        Returns (spatial_shapes, level_start_index, valid_ratios) for the
        flattened multi-level sequence.
        """
        device = multi_level_masks[0].device
        if torchvision._is_tracing():
            # torch.Tensor.shape does not export well to ONNX;
            # use operators.shape_as_tensor instead.
            from torch.onnx import operators
            per_level = [operators.shape_as_tensor(m)[-2:] for m in multi_level_masks]
            spatial_shapes = torch.stack(per_level).to(device)
        else:
            spatial_shapes = torch.as_tensor(
                [m.shape[-2:] for m in multi_level_masks], device=device
            )
        # Offset of each level's first token in the flattened sequence.
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack(
            [self.get_valid_ratios(m) for m in multi_level_masks], 1
        )
        return spatial_shapes, level_start_index, valid_ratios

    @staticmethod
    def get_valid_ratios(mask):
        """Fraction of un-padded columns/rows per image; mask is True on padding."""
        _, height, width = mask.shape
        valid_rows = (~mask[:, :, 0]).sum(1)
        valid_cols = (~mask[:, 0, :]).sum(1)
        # [n, 2] as (w_ratio, h_ratio)
        return torch.stack(
            [valid_cols.float() / width, valid_rows.float() / height], -1
        )


class TwostageTransformer(DETRBaseTransformer):
    """Base class for two-stage transformers (DeformableTransformer,
    DabTransformer, DINOTransformer, AlignTransformer). Adds the
    encoder-output projection and the proposal generation used by the
    two-stage query selection.
    """
    def __init__(self, num_feature_levels, embed_dim):
        super().__init__(num_feature_levels, embed_dim)
        self.enc_output = nn.Linear(embed_dim, embed_dim)
        self.enc_output_norm = nn.LayerNorm(embed_dim)
        self._init_weights_two_stage_transformer()

    def _init_weights_two_stage_transformer(self):
        nn.init.xavier_uniform_(self.enc_output.weight)
        nn.init.zeros_(self.enc_output.bias)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        """Turn encoder memory into per-position box proposals.

        memory: (n, s, c) flattened encoder features; memory_padding_mask:
        (n, s), True on padding; spatial_shapes: (levels, 2) of (h, w).
        Returns the projected memory and inverse-sigmoid box proposals, with
        padded / out-of-range positions filled with +inf.
        """
        batch = memory.shape[0]
        if torchvision._is_tracing():
            # Iterating a Tensor warns under tracing;
            # unbind into list[Tensor] instead.
            shape_list = [row.unbind(0) for row in spatial_shapes.unbind(0)]
        else:
            # Plain Python ints avoid launching tiny indexing kernels below.
            shape_list = spatial_shapes.tolist()

        proposals = []
        offset = 0
        for level, (h, w) in enumerate(shape_list):
            level_mask = memory_padding_mask[:, offset:(offset + h * w)].view(batch, h, w, 1)
            valid_h = torch.sum(~level_mask[:, :, 0, 0], 1)
            valid_w = torch.sum(~level_mask[:, 0, :, 0], 1)

            ys, xs = torch.meshgrid(
                torch.linspace(0, h - 1, h, dtype=torch.float32, device=memory.device),
                torch.linspace(0, w - 1, w, dtype=torch.float32, device=memory.device),
                indexing="ij",
            )
            centers = torch.stack([xs, ys], -1)  # [h, w, 2]
            # Normalize pixel centers by the un-padded extent of each image.
            scale = torch.stack([valid_w, valid_h], -1).view(batch, 1, 1, 2)
            centers = (centers.expand(batch, -1, -1, -1) + 0.5) / scale  # [n, h, w, 2]
            # Base box size 0.05, doubled at every coarser level.
            sizes = torch.ones_like(centers) * 0.05 * 2.0**level
            proposals.append(torch.cat([centers, sizes], -1).view(batch, -1, 4))
            offset += h * w

        output_proposals = torch.cat(proposals, 1)
        in_range = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        output_proposals = torch.log(output_proposals / (1 - output_proposals))  # inverse_sigmoid
        invalid = memory_padding_mask.unsqueeze(-1) | ~in_range
        output_proposals = output_proposals.masked_fill(invalid, float("inf"))

        # Zero out padded / invalid positions before projecting.
        output_memory = memory * (~memory_padding_mask.unsqueeze(-1)) * in_range
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals