ssn_head.py 16.4 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from ..builder import HEADS


def parse_stage_config(stage_cfg):
    """Parse the STPP config of a single stage.

    Args:
        stage_cfg (int | tuple[int] | list[int]):
            Config of structured temporal pyramid pooling for one stage.
            An int ``n`` is shorthand for ``(n, )``. A sequence lists the
            number of parts of each pyramid level. Lists are accepted so
            that configs round-tripped through JSON/YAML (which have no
            tuple type) keep working.

    Returns:
        tuple[tuple[int], int]:
            Config of structured temporal pyramid pooling (always a tuple)
            and total number of parts (number of multipliers).

    Raises:
        ValueError: If ``stage_cfg`` is neither an int nor a tuple/list.
    """
    if isinstance(stage_cfg, int):
        return (stage_cfg, ), stage_cfg
    if isinstance(stage_cfg, (tuple, list)):
        # tuple() of a tuple returns the very same object, so tuple inputs
        # behave exactly as before; lists are normalized to tuples.
        return tuple(stage_cfg), sum(stage_cfg)
    raise ValueError(f'Incorrect STPP config {stage_cfg}')


class STPPTrain(nn.Module):
    """Structured temporal pyramid pooling (STPP) for SSN at training.

    Every proposal is represented by ``sum(num_segments_list)`` consecutive
    segment features, split into a starting, a course and an ending stage.
    Each stage is average-pooled over the pyramid parts described by its
    entry of ``stpp_stage``, and the pooled parts of all stages are
    concatenated along the feature dimension.

    Args:
        stpp_stage (tuple): Config of structured temporal pyramid pooling,
            one entry (int or tuple of ints) per stage.
            Default: (1, (1, 2), 1).
        num_segments_list (tuple): Number of segments to be sampled
            in three stages. Default: (2, 5, 2).
    """

    def __init__(self, stpp_stage=(1, (1, 2), 1), num_segments_list=(2, 5, 2)):
        super().__init__()

        # (parts, multiplier) for the starting, course and ending stages.
        parsed = [parse_stage_config(stpp_stage[k]) for k in range(3)]
        self.stpp_stages = tuple(parts for parts, _ in parsed)
        self.multiplier_list = tuple(mult for _, mult in parsed)
        # Total number of pooled parts over all three stages.
        self.num_multipliers = sum(self.multiplier_list)

        self.num_segments_list = num_segments_list

    @staticmethod
    def _extract_stage_feature(stage_feat, stage_parts, num_multipliers,
                               scale_factors, num_samples):
        """Pool one stage according to its STPP config.

        Args:
            stage_feat (torch.Tensor): Features of the stage, indexed as
                ``(num_samples, stage_len, feat_dim)``.
            stage_parts (tuple): Number of parts of each pyramid level.
            num_multipliers (int): Total number of parts in the stage; every
                pooled part is divided by it.
            scale_factors (torch.Tensor | None): Per-sample ratios of the
                effective sampling lengths to augmented lengths, or None to
                skip rescaling.
            num_samples (int): Number of samples.

        Returns:
            list[torch.Tensor]: One pooled ``(num_samples, feat_dim)`` tensor
            per part, in pyramid order.
        """
        length = stage_feat.size(1)
        pooled_parts = []
        for level_parts in stage_parts:
            # ``level_parts + 1`` evenly spaced, int-truncated boundaries;
            # the 1e-5 keeps ``length`` itself as the last boundary.
            bounds = torch.arange(0, length + 1e-5,
                                  length / level_parts).int()
            for j in range(level_parts):
                segment = stage_feat[:, bounds[j]:bounds[j + 1], :]
                pooled = segment.mean(dim=1) / num_multipliers
                if scale_factors is not None:
                    pooled = pooled * scale_factors.view(num_samples, 1)
                pooled_parts.append(pooled)
        return pooled_parts

    def forward(self, x, scale_factors):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data; reshaped to
                ``(-1, sum(num_segments_list), x.size(1))``.
            scale_factors (torch.Tensor): Ratios of the effective sampling
                lengths to augmented lengths, reshaped to ``(-1, 2)``.
                Column 0 rescales the starting stage and column 1 the
                ending stage.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                The mean course-stage feature (for activity scores) and the
                concatenated STPP feature (for completeness scores).
        """
        seg = self.num_segments_list
        course_start = seg[0]
        course_stop = course_start + seg[1]
        total_segments = course_stop + seg[2]

        feat_dim = x.size(1)
        x = x.view(-1, total_segments, feat_dim)
        num_samples = x.size(0)
        scale_factors = scale_factors.view(-1, 2)

        course = x[:, course_start:course_stop, :]
        # The course stage is never rescaled; the boundary stages are.
        stage_inputs = (
            (x[:, :course_start, :], scale_factors[:, 0]),
            (course, None),
            (x[:, course_stop:, :], scale_factors[:, 1]),
        )

        pooled = []
        for (feat, factor), parts, mult in zip(stage_inputs,
                                               self.stpp_stages,
                                               self.multiplier_list):
            pooled.extend(
                self._extract_stage_feature(feat, parts, mult, factor,
                                            num_samples))

        return course.mean(dim=1), torch.cat(pooled, dim=1)


class STPPTest(nn.Module):
    """Structured temporal pyramid pooling for SSN at testing.

    Pools a score matrix whose rows are indexed by proposal ticks into
    per-proposal activity, completeness and (optionally) regression scores.
    In this file the input comes from ``SSNHead.test_fc``. The input
    columns are laid out as
    ``[activity | completeness x num_multipliers | regression x
    num_multipliers]``: one block of completeness (and regression) scores
    per STPP part.

    Args:
        num_classes (int): Number of classes to be classified.
        use_regression (bool): Whether to perform regression or not.
            Default: True.
        stpp_stage (tuple): Config of structured temporal pyramid pooling,
            one entry (int or tuple of ints) for each of the starting,
            course and ending stages. Default: (1, (1, 2), 1).
    """

    def __init__(self,
                 num_classes,
                 use_regression=True,
                 stpp_stage=(1, (1, 2), 1)):
        super().__init__()

        # Score widths. ``num_classes + 1`` activity scores (presumably an
        # extra background class -- TODO confirm); two regression values
        # per class.
        self.activity_score_len = num_classes + 1
        self.complete_score_len = num_classes
        self.reg_score_len = num_classes * 2
        self.use_regression = use_regression

        starting_parts, starting_multiplier = parse_stage_config(stpp_stage[0])
        course_parts, course_multiplier = parse_stage_config(stpp_stage[1])
        ending_parts, ending_multiplier = parse_stage_config(stpp_stage[2])

        # Total number of STPP parts across the three stages.
        self.num_multipliers = (
            starting_multiplier + course_multiplier + ending_multiplier)
        # Expected width of the input ``x`` (asserted in ``forward``).
        if self.use_regression:
            self.feat_dim = (
                self.activity_score_len + self.num_multipliers *
                (self.complete_score_len + self.reg_score_len))
        else:
            self.feat_dim = (
                self.activity_score_len +
                self.num_multipliers * self.complete_score_len)
        self.stpp_stage = (starting_parts, course_parts, ending_parts)

        # Consecutive column ranges of the three score groups inside ``x``.
        self.activity_slice = slice(0, self.activity_score_len)
        self.complete_slice = slice(
            self.activity_slice.stop, self.activity_slice.stop +
            self.complete_score_len * self.num_multipliers)
        self.reg_slice = slice(
            self.complete_slice.stop, self.complete_slice.stop +
            self.reg_score_len * self.num_multipliers)

    @staticmethod
    def _pyramids_pooling(out_scores, index, raw_scores, ticks, scale_factors,
                          score_len, stpp_stage):
        """Perform pyramids pooling for one proposal.

        For each stage ``k``, rows ``ticks[k]`` to ``ticks[k + 1]`` of
        ``raw_scores`` are split into the parts given by ``stpp_stage[k]``.
        Each part:

        - averages its rows over its own column block of width
          ``score_len``;
        - is multiplied by the stage's scale factor: ``scale_factors[0]``
          for the first stage, ``scale_factors[1]`` for the last, and 1.0
          otherwise;
        - is added into ``out_scores[index]``.

        Args:
            out_scores (torch.Tensor): Scores to be returned; row ``index``
                is accumulated into in place.
            index (int): Index of output scores.
            raw_scores (torch.Tensor): Raw scores before STPP, with one
                column block of width ``score_len`` per STPP part.
            ticks (torch.Tensor): Stage boundaries of the proposal, used as
                row indices into ``raw_scores``. Needs
                ``len(stpp_stage) + 1`` entries.
            scale_factors (list): Ratios of the effective sampling lengths
                to augmented lengths; entries 0 and 1 are used.
            score_len (int): Length of the score.
            stpp_stage (tuple): Parsed STPP config, one tuple of part counts
                per stage.

        Returns:
            torch.Tensor: ``out_scores`` itself (mutated in place).
        """
        # Index of the column block belonging to the current part; advances
        # by one per part across all stages.
        offset = 0
        for stage_idx, stage_cfg in enumerate(stpp_stage):
            # Only the first and last stages are rescaled.
            if stage_idx == 0:
                scale_factor = scale_factors[0]
            elif stage_idx == len(stpp_stage) - 1:
                scale_factor = scale_factors[1]
            else:
                scale_factor = 1.0

            sum_parts = sum(stage_cfg)
            tick_left = ticks[stage_idx]
            # Force the stage to span at least one row.
            tick_right = float(max(ticks[stage_idx] + 1, ticks[stage_idx + 1]))

            # Stage lies outside the available rows: skip it, but still
            # advance past all of its column blocks.
            if tick_right <= 0 or tick_left >= raw_scores.size(0):
                offset += sum_parts
                continue
            for num_parts in stage_cfg:
                # ``num_parts + 1`` evenly spaced, int-truncated boundaries;
                # the 1e-5 makes ``tick_right`` itself included by arange.
                part_ticks = torch.arange(tick_left, tick_right + 1e-5,
                                          (tick_right - tick_left) /
                                          num_parts).int()

                for i in range(num_parts):
                    part_tick_left = part_ticks[i]
                    part_tick_right = part_ticks[i + 1]
                    # Parts covering no rows contribute nothing, but still
                    # consume a column block (offset is advanced below).
                    if part_tick_right - part_tick_left >= 1:
                        raw_score = raw_scores[part_tick_left:part_tick_right,
                                               offset *
                                               score_len:(offset + 1) *
                                               score_len]
                        raw_scale_score = raw_score.mean(dim=0) * scale_factor
                        # out_scores is allocated on CPU in ``forward``.
                        out_scores[index, :] += raw_scale_score.detach().cpu()
                    offset += 1

        return out_scores

    def forward(self, x, proposal_ticks, scale_factors):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data; ``x.size(1)`` must equal
                ``self.feat_dim`` (asserted).
            proposal_ticks (torch.Tensor): Ticks of proposals to be STPP,
                one row per proposal. Entries ``ticks[k]`` and
                ``ticks[k + 1]`` bound stage ``k``. The activity score uses
                the course-stage rows ``ticks[1]`` to ``ticks[2]``.
            scale_factors (list): Per-proposal ratios of the effective
                sampling lengths to augmented lengths.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                out_activity_scores (torch.Tensor): Activity scores.
                out_complete_scores (torch.Tensor): Completeness scores.
                out_reg_scores (torch.Tensor): Regression scores, or None
                    when ``use_regression`` is False.
                All returned tensors are allocated on CPU with ``x.dtype``.
        """
        assert x.size(1) == self.feat_dim
        num_ticks = proposal_ticks.size(0)

        out_activity_scores = torch.zeros((num_ticks, self.activity_score_len),
                                          dtype=x.dtype)
        raw_activity_scores = x[:, self.activity_slice]

        out_complete_scores = torch.zeros((num_ticks, self.complete_score_len),
                                          dtype=x.dtype)
        raw_complete_scores = x[:, self.complete_slice]

        if self.use_regression:
            out_reg_scores = torch.zeros((num_ticks, self.reg_score_len),
                                         dtype=x.dtype)
            raw_reg_scores = x[:, self.reg_slice]
        else:
            out_reg_scores = None
            raw_reg_scores = None

        for i in range(num_ticks):
            ticks = proposal_ticks[i]

            # Activity: plain mean over the course stage (at least one row).
            out_activity_scores[i, :] = raw_activity_scores[
                ticks[1]:max(ticks[1] + 1, ticks[2]), :].mean(dim=0)

            out_complete_scores = self._pyramids_pooling(
                out_complete_scores, i, raw_complete_scores, ticks,
                scale_factors[i], self.complete_score_len, self.stpp_stage)

            if self.use_regression:
                out_reg_scores = self._pyramids_pooling(
                    out_reg_scores, i, raw_reg_scores, ticks, scale_factors[i],
                    self.reg_score_len, self.stpp_stage)

        return out_activity_scores, out_complete_scores, out_reg_scores


@HEADS.register_module()
class SSNHead(nn.Module):
    """The classification head for SSN.

    In training mode the input features go through :class:`STPPTrain`, then
    through separate activity, completeness and (optionally) regression fc
    layers. For test mode, :meth:`prepare_test_fc` must be called first: it
    fuses the fc layers into a single ``test_fc`` whose output is pooled by
    :class:`STPPTest`.

    Args:
        dropout_ratio (float): Probability of dropout layer. ``0`` disables
            dropout. Default: 0.8.
        in_channels (int): Number of channels for input data. Default: 1024.
        num_classes (int): Number of classes to be classified. Default: 20.
        consensus (dict | None): Config of segmental consensus.
            - ``type`` selects ``'STPPTrain'`` or ``'STPPTest'``.
            - The remaining keys are forwarded to that class's constructor.
              For ``'STPPTest'``, ``num_classes`` is filled in from this
              head.
            - ``None`` means ``dict(type='STPPTrain', stpp_stage=(1, 1, 1),
              num_segments_list=(2, 5, 2))``.
            Default: None.
        use_regression (bool): Whether to perform regression or not.
            Default: True.
        init_std (float): Std value for initialization. Default: 0.001.

    Raises:
        ValueError: If ``consensus['type']`` is not a supported consensus.
    """

    def __init__(self,
                 dropout_ratio=0.8,
                 in_channels=1024,
                 num_classes=20,
                 consensus=None,
                 use_regression=True,
                 init_std=0.001):

        super().__init__()

        if consensus is None:
            # Built per call (no shared mutable default). The keys must
            # match STPPTrain's constructor; the former default used
            # ``standalong_classifier``/``stpp_cfg``/``num_seg``, which
            # STPPTrain rejects with a TypeError.
            consensus = dict(
                type='STPPTrain',
                stpp_stage=(1, 1, 1),
                num_segments_list=(2, 5, 2))

        self.dropout_ratio = dropout_ratio
        self.num_classes = num_classes
        self.use_regression = use_regression
        self.init_std = init_std

        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None

        # Work on a copy so that popping ``type`` (and injecting
        # ``num_classes``) never mutates the caller's config. Based on this
        # copy, the model will utilize different structured temporal pyramid
        # pooling at training and testing.
        # Warning: this copy cannot be removed.
        consensus_ = consensus.copy()
        consensus_type = consensus_.pop('type')
        if consensus_type == 'STPPTrain':
            self.consensus = STPPTrain(**consensus_)
        elif consensus_type == 'STPPTest':
            consensus_['num_classes'] = self.num_classes
            self.consensus = STPPTest(**consensus_)
        else:
            # Fail fast instead of an AttributeError on self.consensus below.
            raise ValueError(
                f'Unsupported consensus type {consensus_type!r}; '
                "expected 'STPPTrain' or 'STPPTest'")

        self.in_channels_activity = in_channels
        # The STPP feature concatenates one in_channels-wide feature per part.
        self.in_channels_complete = (
            self.consensus.num_multipliers * in_channels)
        self.activity_fc = nn.Linear(in_channels, num_classes + 1)
        self.completeness_fc = nn.Linear(self.in_channels_complete,
                                         num_classes)
        if self.use_regression:
            self.regressor_fc = nn.Linear(self.in_channels_complete,
                                          num_classes * 2)

    def init_weights(self):
        """Initiate the parameters from scratch."""
        normal_init(self.activity_fc, std=self.init_std)
        normal_init(self.completeness_fc, std=self.init_std)
        if self.use_regression:
            normal_init(self.regressor_fc, std=self.init_std)

    def prepare_test_fc(self, stpp_feat_multiplier):
        """Reorganize the shape of fully connected layer at testing, in order
        to improve testing efficiency.

        The activity, completeness and regression fc layers are fused into a
        single ``test_fc`` with ``in_features = activity_fc.in_features``.
        Its output columns are ``[activity | completeness per part |
        regression per part]``, with one block of ``out_features`` per part.
        This matches the column slices expected by :class:`STPPTest`.

        Args:
            stpp_feat_multiplier (int): Total number of parts.

        Returns:
            bool: Whether the shape transformation is ready for testing.
        """

        in_features = self.activity_fc.in_features
        out_features = (
            self.activity_fc.out_features +
            self.completeness_fc.out_features * stpp_feat_multiplier)
        if self.use_regression:
            out_features += (
                self.regressor_fc.out_features * stpp_feat_multiplier)
        self.test_fc = nn.Linear(in_features, out_features)

        # Fetch weight and bias of the reorganized fc. The completeness
        # weight is (out, parts * in); regroup it to (parts * out, in) so
        # that each part gets its own block of output rows.
        complete_weight = self.completeness_fc.weight.data.view(
            self.completeness_fc.out_features, stpp_feat_multiplier,
            in_features).transpose(0, 1).contiguous().view(-1, in_features)
        # The bias is split evenly over the parts, presumably so that
        # summing the per-part scores in STPPTest restores it once.
        complete_bias = self.completeness_fc.bias.data.view(1, -1).expand(
            stpp_feat_multiplier, self.completeness_fc.out_features
        ).contiguous().view(-1) / stpp_feat_multiplier

        weight = torch.cat((self.activity_fc.weight.data, complete_weight))
        bias = torch.cat((self.activity_fc.bias.data, complete_bias))

        if self.use_regression:
            # Same regrouping as for completeness.
            reg_weight = self.regressor_fc.weight.data.view(
                self.regressor_fc.out_features, stpp_feat_multiplier,
                in_features).transpose(0,
                                       1).contiguous().view(-1, in_features)
            reg_bias = self.regressor_fc.bias.data.view(1, -1).expand(
                stpp_feat_multiplier, self.regressor_fc.out_features
            ).contiguous().view(-1) / stpp_feat_multiplier
            weight = torch.cat((weight, reg_weight))
            bias = torch.cat((bias, reg_bias))

        self.test_fc.weight.data = weight
        self.test_fc.bias.data = bias
        return True

    def forward(self, x, test_mode=False):
        """Defines the computation performed at every call.

        Args:
            x (tuple): In training mode, ``(features,
                proposal_scale_factor)``, which is passed to the consensus
                module. In test mode, ``(features, proposal_tick_list,
                scale_factor_list)``: ``features`` go through ``test_fc``
                and the result is pooled by the consensus module.
            test_mode (bool): Whether to run the test-time path.
                Default: False.

        Returns:
            tuple: In training mode, ``(activity_scores, complete_scores,
            bbox_preds)``. ``bbox_preds`` has shape ``(N, num_classes, 2)``,
            or is None without regression. In test mode,
            ``(test_scores, activity_scores, completeness_scores,
            bbox_preds)``.
        """
        if not test_mode:
            x, proposal_scale_factor = x
            activity_feat, completeness_feat = self.consensus(
                x, proposal_scale_factor)

            if self.dropout is not None:
                activity_feat = self.dropout(activity_feat)
                completeness_feat = self.dropout(completeness_feat)

            activity_scores = self.activity_fc(activity_feat)
            complete_scores = self.completeness_fc(completeness_feat)
            if self.use_regression:
                bbox_preds = self.regressor_fc(completeness_feat)
                bbox_preds = bbox_preds.view(-1,
                                             self.completeness_fc.out_features,
                                             2)
            else:
                bbox_preds = None
            return activity_scores, complete_scores, bbox_preds

        x, proposal_tick_list, scale_factor_list = x
        test_scores = self.test_fc(x)
        (activity_scores, completeness_scores,
         bbox_preds) = self.consensus(test_scores, proposal_tick_list,
                                      scale_factor_list)

        return (test_scores, activity_scores, completeness_scores, bbox_preds)