# encoding: utf-8
"""
@author:  lingxiao he
@contact: helingxiao3@jd.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.layers import *
from fastreid.modeling.heads import EmbeddingHead
from fastreid.modeling.heads.build import REID_HEADS_REGISTRY
from fastreid.layers.weight_init import weights_init_kaiming


class OcclusionUnit(nn.Module):
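    """Foreground estimation unit used by the DSR head.

    The backbone feature map is max-pooled at four scales; every pooled
    position is scored by a shared linear layer (``mask_layer``) and squashed
    with a sigmoid to give a per-position foreground weight. The L1-normalized
    scores of the first pyramid level are then used to aggregate a global
    foreground descriptor from that level's features.
    """
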
    def __init__(self, in_planes=2048):
        super(OcclusionUnit, self).__init__()
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=4, stride=2, padding=0)
        self.MaxPool3 = nn.MaxPool2d(kernel_size=6, stride=2, padding=0)
        self.MaxPool4 = nn.MaxPool2d(kernel_size=8, stride=2, padding=0)
        self.mask_layer = nn.Linear(in_planes, 1, bias=True)

    def forward(self, x):
        SpaFeat1 = self.MaxPool1(x)  # shape: [n, c, h, w]
        SpaFeat2 = self.MaxPool2(x)
        SpaFeat3 = self.MaxPool3(x)
        SpaFeat4 = self.MaxPool4(x)

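        # Flatten each pyramid level to [n, c, h_i*w_i] and concatenate all
        # pooled positions along the spatial axis.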
        Feat1 = SpaFeat1.view(SpaFeat1.size(0), SpaFeat1.size(1), SpaFeat1.size(2) * SpaFeat1.size(3))
        Feat2 = SpaFeat2.view(SpaFeat2.size(0), SpaFeat2.size(1), SpaFeat2.size(2) * SpaFeat2.size(3))
        Feat3 = SpaFeat3.view(SpaFeat3.size(0), SpaFeat3.size(1), SpaFeat3.size(2) * SpaFeat3.size(3))
        Feat4 = SpaFeat4.view(SpaFeat4.size(0), SpaFeat4.size(1), SpaFeat4.size(2) * SpaFeat4.size(3))
        SpatialFeatAll = torch.cat((Feat1, Feat2, Feat3, Feat4), 2)
        SpatialFeatAll = SpatialFeatAll.transpose(1, 2)  # shape: [n, m, c], m = total pooled positions
        y = self.mask_layer(SpatialFeatAll)
        mask_weight = torch.sigmoid(y[:, :, 0])
        feat_dim = SpaFeat1.size(2) * SpaFeat1.size(3)
        mask_score = F.normalize(mask_weight[:, :feat_dim], p=1, dim=1)
        mask_weight_norm = F.normalize(mask_weight, p=1, dim=1)

        mask_score = mask_score.unsqueeze(1)

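        # Reshape the first pyramid level to [n, h*w, c] and take a weighted
        # sum with the normalized foreground scores -> [n, c, 1, 1] descriptor.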
        SpaFeat1 = SpaFeat1.transpose(1, 2)
        SpaFeat1 = SpaFeat1.transpose(2, 3)  # shape: [n, h, w, c]
        SpaFeat1 = SpaFeat1.view((SpaFeat1.size(0), SpaFeat1.size(1) * SpaFeat1.size(2), -1))  # shape: [n, h*w, c]

        global_feats = mask_score.matmul(SpaFeat1).view(SpaFeat1.shape[0], -1, 1, 1)
        return global_feats, mask_weight, mask_weight_norm


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


@REID_HEADS_REGISTRY.register()
class DSRHead(EmbeddingHead):
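    """Embedding head with an additional occlusion-aware foreground branch.

    Besides the standard ``EmbeddingHead`` pipeline, the backbone features go
    through an ``OcclusionUnit`` to produce a foreground descriptor, which has
    its own bottleneck (``bnneck_occ``) and classifier weight (``weight_occ``).
    At inference the head returns the foreground embedding, the concatenated
    multi-scale spatial features and the normalized foreground mask.
    """
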
    def __init__(self, cfg):
        super().__init__(cfg)

        feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
        with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
        norm_type = cfg.MODEL.HEADS.NORM
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
        self.occ_unit = OcclusionUnit(in_planes=feat_dim)
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=4, stride=2, padding=0)
        self.MaxPool3 = nn.MaxPool2d(kernel_size=6, stride=2, padding=0)
        self.MaxPool4 = nn.MaxPool2d(kernel_size=8, stride=2, padding=0)

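        # Bottleneck for the foreground branch: optional 1x1 conv down to
        # ``embedding_dim``, optionally followed by a bias-frozen norm layer.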
        occ_neck = []
        if embedding_dim > 0:
            occ_neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
            feat_dim = embedding_dim

        if with_bnneck:
            occ_neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))

        self.bnneck_occ = nn.Sequential(*occ_neck)
        self.bnneck_occ.apply(weights_init_kaiming)

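        # Separate classification weight for the foreground embedding.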
        self.weight_occ = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))

    def forward(self, features, targets=None):
        """
        See :class:`EmbeddingHead.forward`.
        """
        SpaFeat1 = self.MaxPool1(features)  # shape: [n, c, h, w]
        SpaFeat2 = self.MaxPool2(features)
        SpaFeat3 = self.MaxPool3(features)
        SpaFeat4 = self.MaxPool4(features)

        Feat1 = SpaFeat1.view(SpaFeat1.size(0), SpaFeat1.size(1), SpaFeat1.size(2) * SpaFeat1.size(3))
        Feat2 = SpaFeat2.view(SpaFeat2.size(0), SpaFeat2.size(1), SpaFeat2.size(2) * SpaFeat2.size(3))
        Feat3 = SpaFeat3.view(SpaFeat3.size(0), SpaFeat3.size(1), SpaFeat3.size(2) * SpaFeat3.size(3))
        Feat4 = SpaFeat4.view(SpaFeat4.size(0), SpaFeat4.size(1), SpaFeat4.size(2) * SpaFeat4.size(3))
        SpatialFeatAll = torch.cat((Feat1, Feat2, Feat3, Feat4), dim=2)

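        # Foreground descriptor and occlusion mask, followed by the dedicated
        # foreground bottleneck.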
        foreground_feat, mask_weight, mask_weight_norm = self.occ_unit(features)
        bn_foreground_feat = self.bnneck_occ(foreground_feat)
        bn_foreground_feat = bn_foreground_feat[..., 0, 0]

        # Evaluation
        if not self.training:
            return bn_foreground_feat, SpatialFeatAll, mask_weight_norm

        # Training
        global_feat = self.pool_layer(features)
        bn_feat = self.bottleneck(global_feat)
        bn_feat = bn_feat[..., 0, 0]

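        # A plain ``Linear`` classifier consumes raw logits; the other
        # classifiers expect cosine logits, hence the normalization of both
        # features and weights.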
        if self.cls_layer.__class__.__name__ == 'Linear':
            pred_class_logits = F.linear(bn_feat, self.weight)
            fore_pred_class_logits = F.linear(bn_foreground_feat, self.weight_occ)
        else:
            pred_class_logits = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
            fore_pred_class_logits = F.linear(F.normalize(bn_foreground_feat), F.normalize(self.weight_occ))

        cls_outputs = self.cls_layer(pred_class_logits, targets)
        fore_cls_outputs = self.cls_layer(fore_pred_class_logits, targets)

        return {
            "cls_outputs": cls_outputs,
            "fore_cls_outputs": fore_cls_outputs,
            "pred_class_logits": pred_class_logits * self.cls_layer.s,
            "features": global_feat[..., 0, 0],
            "foreground_features": foreground_feat[..., 0, 0],
        }
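

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): runs OcclusionUnit
# on a dummy backbone feature map to illustrate the expected shapes. The
# input size [2, 2048, 16, 8] is a hypothetical example; any spatial size
# compatible with the 8x8 max-pool window works.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dummy_feat = torch.randn(2, 2048, 16, 8)  # [n, c, h, w]
    unit = OcclusionUnit(in_planes=2048)
    fg_feat, mask_w, mask_w_norm = unit(dummy_feat)
    # fg_feat: [2, 2048, 1, 1]; mask_w / mask_w_norm: [2, num_pooled_positions]
    print(fg_feat.shape, mask_w.shape, mask_w_norm.shape)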