# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from .. import builder
from ..builder import LOCALIZERS
from .base import BaseTAGClassifier


@LOCALIZERS.register_module()
class SSN(BaseTAGClassifier):
    """Temporal Action Detection with Structured Segment Networks.

    Args:
        backbone (dict): Config for building backbone.
        cls_head (dict): Config for building classification head.
        in_channels (int): Number of channels for input data.
            Default: 3.
        spatial_type (str): Type of spatial pooling ('avg' or 'max'; any
            other value disables pooling). Default: 'avg'.
        dropout_ratio (float): Ratio of dropout. Default: 0.5.
        loss_cls (dict | None): Config for building loss. If None, falls
            back to ``dict(type='SSNLoss')``. Default: None.
        train_cfg (dict | None): Config for training. Default: None.
        test_cfg (dict | None): Config for testing. Default: None.
    """

    def __init__(self,
                 backbone,
                 cls_head,
                 in_channels=3,
                 spatial_type='avg',
                 dropout_ratio=0.5,
                 loss_cls=None,
                 train_cfg=None,
                 test_cfg=None):

        super().__init__(backbone, cls_head, train_cfg, test_cfg)

        # Use a None sentinel instead of a mutable dict default argument;
        # the effective default is unchanged (``dict(type='SSNLoss')``).
        if loss_cls is None:
            loss_cls = dict(type='SSNLoss')

        # Lazily set up the test-mode fully-connected layers on first
        # forward_test call (see prepare_test_fc below).
        self.is_test_prepared = False
        self.in_channels = in_channels

        self.spatial_type = spatial_type
        if self.spatial_type == 'avg':
            self.pool = nn.AvgPool2d((7, 7), stride=1, padding=0)
        elif self.spatial_type == 'max':
            self.pool = nn.MaxPool2d((7, 7), stride=1, padding=0)
        else:
            # Unknown spatial type: skip spatial pooling entirely.
            self.pool = None

        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None
        self.loss_cls = builder.build_loss(loss_cls)

    def forward_train(self, imgs, proposal_scale_factor, proposal_type,
                      proposal_labels, reg_targets, **kwargs):
        """Define the computation performed at every call when training.

        Args:
            imgs (torch.Tensor): Input frames; flattened below so that the
                backbone sees a batch of ``in_channels``-channel images.
            proposal_scale_factor (torch.Tensor): Scale factors of the
                sampled proposals, forwarded to the classification head.
            proposal_type (torch.Tensor): Type label of each proposal
                (e.g. foreground / incomplete / background).
            proposal_labels (torch.Tensor): Class labels of the proposals.
            reg_targets (torch.Tensor): Regression targets for the
                proposal boundaries.

        Returns:
            dict: Losses computed by ``self.loss_cls``.
        """
        # Collapse all leading (batch/proposal/segment) dims into one
        # batch dim of single images: (-1, C) + spatial dims.
        imgs = imgs.reshape((-1, self.in_channels) + imgs.shape[4:])

        x = self.extract_feat(imgs)

        if self.pool is not None:
            x = self.pool(x)
        if self.dropout is not None:
            x = self.dropout(x)

        activity_scores, completeness_scores, bbox_preds = self.cls_head(
            (x, proposal_scale_factor))

        loss = self.loss_cls(activity_scores, completeness_scores, bbox_preds,
                             proposal_type, proposal_labels, reg_targets,
                             self.train_cfg)
        # Shallow-copy into a plain dict for the standard loss interface.
        loss_dict = dict(**loss)

        return loss_dict

    def forward_test(self, imgs, relative_proposal_list, scale_factor_list,
                     proposal_tick_list, reg_norm_consts, **kwargs):
        """Define the computation performed at every call when testing.

        Args:
            imgs (torch.Tensor): Input frames with a leading crop dim;
                reshaped to (num_crops, num_ticks, C, H, W) below.
            relative_proposal_list (torch.Tensor): Proposals in relative
                (normalized) coordinates, with a leading batch dim of 1.
            scale_factor_list (torch.Tensor): Per-proposal scale factors.
            proposal_tick_list (torch.Tensor): Proposal boundaries in
                feature-tick indices.
            reg_norm_consts (torch.Tensor): Mean/std constants used to
                de-normalize the regression outputs (row 0: means,
                row 1: stds).

        Returns:
            list[dict]: A single-element list holding the proposals and
            their activity/completeness scores and (optionally)
            de-normalized bbox predictions, all as numpy arrays.
        """
        num_crops = imgs.shape[0]
        imgs = imgs.reshape((num_crops, -1, self.in_channels) + imgs.shape[3:])
        num_ticks = imgs.shape[1]

        output = []
        # Process the ticks in minibatches to bound peak GPU memory.
        minibatch_size = self.test_cfg.ssn.sampler.batch_size
        for idx in range(0, num_ticks, minibatch_size):
            chunk = imgs[:, idx:idx +
                         minibatch_size, :, :, :].view((-1, ) + imgs.shape[2:])
            x = self.extract_feat(chunk)
            if self.pool is not None:
                x = self.pool(x)
            # Merge crop to save memory: average features over crops.
            x = x.reshape((num_crops, x.size(0) // num_crops, -1)).mean(dim=0)
            output.append(x)
        output = torch.cat(output, dim=0)

        # Drop the batch dim of 1 added by the data pipeline.
        relative_proposal_list = relative_proposal_list.squeeze(0)
        proposal_tick_list = proposal_tick_list.squeeze(0)
        scale_factor_list = scale_factor_list.squeeze(0)
        reg_norm_consts = reg_norm_consts.squeeze(0)

        # One-time conversion of the head's FCs into their test-mode form.
        if not self.is_test_prepared:
            self.is_test_prepared = self.cls_head.prepare_test_fc(
                self.cls_head.consensus.num_multipliers)

        (output, activity_scores, completeness_scores,
         bbox_preds) = self.cls_head(
             (output, proposal_tick_list, scale_factor_list), test_mode=True)

        relative_proposal_list = relative_proposal_list.cpu().numpy()
        activity_scores = activity_scores.cpu().numpy()
        completeness_scores = completeness_scores.cpu().numpy()
        reg_norm_consts = reg_norm_consts.cpu().numpy()
        if bbox_preds is not None:
            bbox_preds = bbox_preds.view(-1, self.cls_head.num_classes, 2)
            # De-normalize: pred * std + mean, per coordinate.
            bbox_preds[:, :, 0] = (
                bbox_preds[:, :, 0] * reg_norm_consts[1, 0] +
                reg_norm_consts[0, 0])
            bbox_preds[:, :, 1] = (
                bbox_preds[:, :, 1] * reg_norm_consts[1, 1] +
                reg_norm_consts[0, 1])
            bbox_preds = bbox_preds.cpu().numpy()

        result = [
            dict(
                relative_proposal_list=relative_proposal_list,
                activity_scores=activity_scores,
                completeness_scores=completeness_scores,
                bbox_preds=bbox_preds)
        ]

        return result