# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F


class GradCAM:
    """GradCAM class helps create visualization results.

    Visualization results are blended from heatmaps and input images.
    This class is modified from
    https://github.com/facebookresearch/SlowFast/blob/master/slowfast/visualization/gradcam_utils.py # noqa
    For more information about GradCAM, please visit:
    https://arxiv.org/pdf/1610.02391.pdf
    """

    def __init__(self, model, target_layer_name, colormap='viridis'):
        """Create GradCAM class with recognizer, target layername & colormap.

        Args:
            model (nn.Module): the recognizer model to be used.
            target_layer_name (str): name of convolutional layer to
                be used to get gradients and feature maps from for creating
                localization maps.
            colormap (Optional[str]): matplotlib colormap used to create
                heatmap. Default: 'viridis'. For more information, please visit
                https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html
        """
        from ..models.recognizers import Recognizer2D, Recognizer3D
        if isinstance(model, Recognizer2D):
            self.is_recognizer2d = True
        elif isinstance(model, Recognizer3D):
            self.is_recognizer2d = False
        else:
            raise ValueError(
                'GradCAM utils only support Recognizer2D & Recognizer3D.')

        self.model = model
        self.model.eval()
        self.target_gradients = None
        self.target_activations = None

        import matplotlib.pyplot as plt
        self.colormap = plt.get_cmap(colormap)
        self.data_mean = torch.tensor(model.cfg.img_norm_cfg['mean'])
        self.data_std = torch.tensor(model.cfg.img_norm_cfg['std'])
        self._register_hooks(target_layer_name)

    def _register_hooks(self, layer_name):
        """Register forward and backward hook to a layer, given layer_name, to
        obtain gradients and activations.

        Args:
            layer_name (str): name of the layer.
        """

        def get_gradients(module, grad_input, grad_output):
            self.target_gradients = grad_output[0].detach()

        def get_activations(module, input, output):
            self.target_activations = output.clone().detach()

        layer_ls = layer_name.split('/')
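        # walk the module tree one path segment at a time; e.g. a name like
        # 'backbone/layer4' (illustrative only) resolves to
        # ``model._modules['backbone']._modules['layer4']``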
        prev_module = self.model
        for layer in layer_ls:
            prev_module = prev_module._modules[layer]

        target_layer = prev_module
        target_layer.register_forward_hook(get_activations)
        target_layer.register_backward_hook(get_gradients)
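        # note: ``register_backward_hook`` is deprecated in recent PyTorch
        # versions in favor of ``register_full_backward_hook``; the legacy
        # hook is kept for compatibility with older PyTorch releases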

    def _calculate_localization_map(self, inputs, use_labels, delta=1e-20):
        """Calculate localization map for all inputs with Grad-CAM.

        Args:
            inputs (dict): Model inputs generated by the test pipeline,
                containing at least two keys, ``imgs`` and ``label``.
            use_labels (bool): Whether to use the given labels when
                generating the localization map. Labels are taken from
                ``inputs['label']``.
            delta (float): Small constant used to normalize the
                localization map; it must satisfy
                ``localization_map_max - localization_map_min >> delta``.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: (localization_map, preds)
                localization_map (torch.Tensor): The localization map for
                    the input imgs.
                preds (torch.Tensor): Model predictions for ``inputs``, with
                    shape (batch_size, num_classes).
        """
        inputs['imgs'] = inputs['imgs'].clone()

        # use score before softmax
        self.model.test_cfg['average_clips'] = 'score'
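        # ('score' averages raw class scores over clips instead of
        # post-softmax probabilities)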
        # model forward & backward
        preds = self.model(gradcam=True, **inputs)
        if use_labels:
            labels = inputs['label']
            if labels.ndim == 1:
                labels = labels.unsqueeze(-1)
            score = torch.gather(preds, dim=1, index=labels)
        else:
            score = torch.max(preds, dim=-1)[0]
        self.model.zero_grad()
        score = torch.sum(score)
        score.backward()
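        # the forward pass filled ``self.target_activations`` via the forward
        # hook; the backward pass above filled ``self.target_gradients``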

        if self.is_recognizer2d:
            # [batch_size, num_segments, 3, H, W]
            b, t, _, h, w = inputs['imgs'].size()
        else:
            # [batch_size, num_crops*num_clips, 3, clip_len, H, W]
            b1, b2, _, t, h, w = inputs['imgs'].size()
            b = b1 * b2

        gradients = self.target_gradients
        activations = self.target_activations
        if self.is_recognizer2d:
            # [B*Tg, C', H', W']
            b_tg, c, _, _ = gradients.size()
            tg = b_tg // b
        else:
            # source shape: [B, C', Tg, H', W']
            _, c, tg, _, _ = gradients.size()
            # target shape: [B, Tg, C', H', W']
            gradients = gradients.permute(0, 2, 1, 3, 4)
            activations = activations.permute(0, 2, 1, 3, 4)

        # calculate & resize to [B, 1, T, H, W]
        weights = torch.mean(gradients.view(b, tg, c, -1), dim=3)
        weights = weights.view(b, tg, c, 1, 1)
        activations = activations.view([b, tg, c] +
                                       list(activations.size()[-2:]))
        localization_map = torch.sum(
            weights * activations, dim=2, keepdim=True)
        localization_map = F.relu(localization_map)
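        # [B, Tg, 1, H', W'] -> [B, 1, Tg, H', W'], so that trilinear
        # interpolation treats (Tg, H', W') as depth/height/width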
        localization_map = localization_map.permute(0, 2, 1, 3, 4)
        localization_map = F.interpolate(
            localization_map,
            size=(t, h, w),
            mode='trilinear',
            align_corners=False)

        # Normalize the localization map.
        localization_map_min, localization_map_max = (
            torch.min(localization_map.view(b, -1), dim=-1, keepdim=True)[0],
            torch.max(localization_map.view(b, -1), dim=-1, keepdim=True)[0])
        localization_map_min = torch.reshape(
            localization_map_min, shape=(b, 1, 1, 1, 1))
        localization_map_max = torch.reshape(
            localization_map_max, shape=(b, 1, 1, 1, 1))
        localization_map = (localization_map - localization_map_min) / (
            localization_map_max - localization_map_min + delta)
        localization_map = localization_map.data

        return localization_map.squeeze(dim=1), preds

    def _alpha_blending(self, localization_map, input_imgs, alpha):
        """Blend heatmaps and model input images and get visulization results.

        Args:
            localization_map (torch.Tensor): localization map for all inputs,
                generated with Grad-CAM
            input_imgs (torch.Tensor): model inputs, normed images.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].
        Returns:
            torch.Tensor: blending results for localization map and input
                images, with shape [B, T, H, W, 3] and pixel values in
                RGB order within range [0, 1].
        """
        # localization_map shape [B, T, H, W]
        localization_map = localization_map.cpu()

        # heatmap shape [B, T, H, W, 3] in RGB order
        heatmap = self.colormap(localization_map.detach().numpy())
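        # matplotlib colormaps return RGBA values; drop the alpha channel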
        heatmap = heatmap[:, :, :, :, :3]
        heatmap = torch.from_numpy(heatmap)

        # Permute input imgs to [B, T, H, W, 3], like heatmap
        if self.is_recognizer2d:
            # Recognizer2D input (B, T, C, H, W)
            curr_inp = input_imgs.permute(0, 1, 3, 4, 2)
        else:
            # Recognizer3D input (B', num_clips*num_crops, C, T, H, W)
            # B = B' * num_clips * num_crops
            curr_inp = input_imgs.view([-1] + list(input_imgs.size()[2:]))
            curr_inp = curr_inp.permute(0, 2, 3, 4, 1)

        # renormalize input imgs to [0, 1]
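        # (this inverts the test-pipeline normalization
        # ``x = (raw - mean) / std``, then maps raw [0, 255] pixels to [0, 1])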
        curr_inp = curr_inp.cpu()
        curr_inp *= self.data_std
        curr_inp += self.data_mean
        curr_inp /= 255.

        # alpha blending
        blended_imgs = alpha * heatmap + (1 - alpha) * curr_inp

        return blended_imgs

    def __call__(self, inputs, use_labels=False, alpha=0.5):
        """Visualize the localization maps on their corresponding inputs as
        heatmap, using Grad-CAM.

        Generate visualization results for **ALL CROPS**.
        For example, for I3D model, if `clip_len=32, num_clips=10` and
        use `ThreeCrop` in test pipeline, then for every model inputs,
        there are 960(32*10*3) images generated.

        Args:
            inputs (dict): model inputs, generated by test pipeline,
                at least including two keys, ``imgs`` and ``label``.
            use_labels (bool): Whether to use given labels to generate
                localization map. Labels are in ``inputs['label']``.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].
        Returns:
            blended_imgs (torch.Tensor): Visualization results, blended by
                localization maps and model inputs.
            preds (torch.Tensor): Model predictions for inputs.
        """

        # localization_map shape [B, T, H, W]
        # preds shape [batch_size, num_classes]
        localization_map, preds = self._calculate_localization_map(
            inputs, use_labels=use_labels)

        # blended_imgs shape [B, T, H, W, 3]
        blended_imgs = self._alpha_blending(localization_map, inputs['imgs'],
                                            alpha)

        # blended_imgs shape [B, T, H, W, 3]
        # preds shape [batch_size, num_classes]
        # Recognizer2D: B = batch_size, T = num_segments
        # Recognizer3D: B = batch_size * num_crops * num_clips, T = clip_len
        return blended_imgs, preds