# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import Hook
from mmcv.runner.hooks.lr_updater import LrUpdaterHook, StepLrUpdaterHook
from torch.nn.modules.utils import _ntuple

from mmaction.core.lr import RelativeStepLrUpdaterHook
from mmaction.utils import get_root_logger


def modify_subbn3d_num_splits(logger, module, num_splits):
    """Recursively modify the number of splits of subbn3ds in module.
    Inheritates the running_mean and running_var from last subbn.bn.

    Args:
        logger (:obj:`logging.Logger`): The logger to log information.
        module (nn.Module): The module to be modified.
        num_splits (int): The targeted number of splits.
    Returns:
        int: The number of subbn3d modules modified.
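
    Example:
        A minimal usage sketch (assuming ``model`` contains SubBatchNorm3D
        layers and ``logger = get_root_logger()``)::

            num_modified = modify_subbn3d_num_splits(logger, model, 2)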
    """
    count = 0
    for child in module.children():
        from mmaction.models import SubBatchNorm3D
        if isinstance(child, SubBatchNorm3D):
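            # split_bn normalizes each of the ``num_splits`` sub-batches
            # separately, so its channel dimension is
            # ``num_features * num_splits``.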
            new_split_bn = nn.BatchNorm3d(
                child.num_features * num_splits, affine=False).cuda()
            new_state_dict = new_split_bn.state_dict()

            for param_name, param in child.bn.state_dict().items():
                origin_param_shape = param.size()
                new_param_shape = new_state_dict[param_name].size()
                if len(origin_param_shape) == 1 and len(
                        new_param_shape
                ) == 1 and new_param_shape[0] >= origin_param_shape[
                        0] and new_param_shape[0] % origin_param_shape[0] == 0:
                    # tile the 1-D params (weight / bias / running_mean /
                    # running_var) so every split starts from the stats of
                    # the aggregate bn
                    new_state_dict[param_name] = torch.cat(
                        [param] *
                        (new_param_shape[0] // origin_param_shape[0]))
                else:
                    logger.info(f'skip {param_name}')

            child.num_splits = num_splits
            new_split_bn.load_state_dict(new_state_dict)
            child.split_bn = new_split_bn
            count += 1
        else:
            count += modify_subbn3d_num_splits(logger, child, num_splits)
    return count


class LongShortCycleHook(Hook):
    """A multigrid method for efficiently training video models.

    This hook defines the multigrid training schedule and updates the config
    accordingly, as proposed in `A Multigrid Method for Efficiently
    Training Video Models <https://arxiv.org/abs/1912.00998>`_.

    Args:
        cfg (:obj:`mmcv.ConfigDict`): The whole config for the experiment.
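
    Example:
        A minimal, illustrative ``multigrid`` config. The field names follow
        how this hook reads ``cfg.multigrid``; the concrete values are
        assumptions, not required defaults::

            multigrid = dict(
                long_cycle=True,
                short_cycle=True,
                epoch_factor=1.5,
                long_cycle_factors=[(0.25, 0.5**0.5), (0.5, 0.5**0.5),
                                    (0.5, 1), (1, 1)],
                short_cycle_factors=(0.5, 0.5**0.5))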
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.multi_grid_cfg = cfg.get('multigrid', None)
        self.data_cfg = cfg.get('data', None)
        assert (self.multi_grid_cfg is not None and self.data_cfg is not None)
        self.logger = get_root_logger()
        self.logger.info(self.multi_grid_cfg)

    def before_run(self, runner):
        """Called before running, change the StepLrUpdaterHook to
        RelativeStepLrHook."""
        self._init_schedule(runner, self.multi_grid_cfg, self.data_cfg)
        steps = [s[-1] for s in self.schedule]
        steps.insert(-1, (steps[-2] + steps[-1]) // 2)  # add finetune stage
        for index, hook in enumerate(runner.hooks):
            if isinstance(hook, StepLrUpdaterHook):
                base_lr = hook.base_lr[0]
                gamma = hook.gamma
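                # each stage's lr is the step-decayed base lr
                # (base_lr * gamma**step_index) scaled by the stage's
                # relative batch size (s[1][0])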
                lrs = [base_lr * gamma**s[0] * s[1][0] for s in self.schedule]
                # lrs for the inserted finetune stage
                lrs = lrs[:-1] + [lrs[-2], lrs[-1] * gamma]
                new_hook = RelativeStepLrUpdaterHook(runner, steps, lrs)
                runner.hooks[index] = new_hook

    def before_train_epoch(self, runner):
        """Before training epoch, update the runner based on long-cycle
        schedule."""
        self._update_long_cycle(runner)

    def _update_long_cycle(self, runner):
        """Before every epoch, check if long cycle shape should change. If it
        should, change the pipelines accordingly.

        change dataloader and model's subbn3d(split_bn)
        """
        base_b, base_t, base_s = self._get_schedule(runner.epoch)

        # rebuild dataset
        from mmaction.datasets import build_dataset
        resize_list = []
        for trans in self.cfg.data.train.pipeline:
            if trans['type'] == 'SampleFrames':
                curr_t = trans['clip_len']
                trans['clip_len'] = base_t
                trans['frame_interval'] = (curr_t *
                                           trans['frame_interval']) / base_t
            elif trans['type'] == 'Resize':
                resize_list.append(trans)
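        # only the final Resize in the pipeline determines the training
        # resolution, so only its scale is overridden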
        resize_list[-1]['scale'] = _ntuple(2)(base_s)

        ds = build_dataset(self.cfg.data.train)

        from mmaction.datasets import build_dataloader

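        # ``base_b`` is the relative batch factor from the long-cycle
        # schedule, so the per-GPU batch size grows when the clip gets
        # shorter or smaller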
        dataloader = build_dataloader(
            ds,
            self.data_cfg.videos_per_gpu * base_b,
            self.data_cfg.workers_per_gpu,
            dist=True,
            num_gpus=len(self.cfg.gpu_ids),
            drop_last=True,
            seed=self.cfg.get('seed', None),
        )
        runner.data_loader = dataloader
        self.logger.info('Rebuild runner.data_loader')

        # `runner._max_epochs` was changed by the multigrid schedule,
        # so update `_max_iters` here
        runner._max_iters = runner._max_epochs * len(runner.data_loader)

        # rebuild all the sub_batch_bn layers
        num_modifies = modify_subbn3d_num_splits(self.logger, runner.model,
                                                 base_b)
        self.logger.info(f'{num_modifies} subbns modified to {base_b}.')

    def _get_long_cycle_schedule(self, runner, cfg):
        # `schedule` is a list of [step_index, base_shape, epochs]
        schedule = []
        avg_bs = []
        all_shapes = []
        self.default_size = self.default_t * self.default_s**2
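        # for each long-cycle shape (t, s), the batch factor is chosen so
        # that b * t * s**2 stays close to ``default_size``, keeping the
        # per-iteration cost roughly constant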
        for t_factor, s_factor in cfg.long_cycle_factors:
            base_t = int(round(self.default_t * t_factor))
            base_s = int(round(self.default_s * s_factor))
            if cfg.short_cycle:
                shapes = [
                    [base_t,
                     int(round(self.default_s * cfg.short_cycle_factors[0]))],
                    [base_t,
                     int(round(self.default_s * cfg.short_cycle_factors[1]))],
                    [base_t, base_s],
                ]
            else:
                shapes = [[base_t, base_s]]
            # calculate the batchsize, shape = [batchsize, #frames, scale]
            shapes = [[
                int(round(self.default_size / (s[0] * s[1]**2))), s[0], s[1]
            ] for s in shapes]
            avg_bs.append(np.mean([s[0] for s in shapes]))
            all_shapes.append(shapes)

        for hook in runner.hooks:
            if isinstance(hook, LrUpdaterHook):
                if isinstance(hook, StepLrUpdaterHook):
                    steps = hook.step if isinstance(hook.step,
                                                    list) else [hook.step]
                    steps = [0] + steps
                    break
                else:
                    raise NotImplementedError(
                        'Only step scheduler supports multi grid now')
        total_iters = 0
        default_iters = steps[-1]
        for step_index in range(len(steps) - 1):
            # except the final step
            step_epochs = steps[step_index + 1] - steps[step_index]
            # number of epochs for this step
            for long_cycle_index, shapes in enumerate(all_shapes):
                cur_epochs = (
                    step_epochs * avg_bs[long_cycle_index] / sum(avg_bs))
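                # larger-batch stages run fewer iterations per epoch, so the
                # relative iteration cost below is epochs / avg batch size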
                cur_iters = cur_epochs / avg_bs[long_cycle_index]
                total_iters += cur_iters
                schedule.append((step_index, shapes[-1], cur_epochs))
        iter_saving = default_iters / total_iters
        final_step_epochs = runner.max_epochs - steps[-1]
        # make the fine-tuning phase have the same amount of iteration
        # saving as the rest of the training
        ft_epochs = final_step_epochs / iter_saving * avg_bs[-1]
        # `schedule` only records the long-cycle shape; short-cycle
        # shapes are ignored here
        schedule.append((step_index + 1, all_shapes[-1][-1], ft_epochs))

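        # rescale each stage's epochs so that their sum matches the enlarged
        # epoch budget ``max_epochs * epoch_factor``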
        x = (
            runner.max_epochs * cfg.epoch_factor /
            sum(s[-1] for s in schedule))
        runner._max_epochs = int(runner._max_epochs * cfg.epoch_factor)
        final_schedule = []
        total_epochs = 0
        for s in schedule:
            # rescale the epochs of this stage by `x`
            epochs = s[2] * x
            total_epochs += epochs
            final_schedule.append((s[0], s[1], int(round(total_epochs))))
        self.logger.info(final_schedule)
        return final_schedule

    def _print_schedule(self, schedule):
        """logging the schedule."""
        self.logger.info('\tLongCycleId\tBase shape\tEpochs\t')
        for s in schedule:
            self.logger.info(f'\t{s[0]}\t{s[1]}\t{s[2]}\t')

    def _get_schedule(self, epoch):
        """Returning the corresponding shape."""
        for s in self.schedule:
            if epoch < s[-1]:
                return s[1]
        return self.schedule[-1][1]

    def _init_schedule(self, runner, multi_grid_cfg, data_cfg):
        """Initialize the multigrid shcedule.

        Args:
            runner (:obj: `mmcv.Runner`): The runner within which to train.
            multi_grid_cfg (:obj: `mmcv.ConfigDict`): The multigrid config.
            data_cfg (:obj: `mmcv.ConfigDict`): The data config.
        """
        self.default_bs = data_cfg.videos_per_gpu
        data_cfg = data_cfg.get('train', None)
        final_resize_cfg = [
            aug for aug in data_cfg.pipeline if aug.type == 'Resize'
        ][-1]
        if isinstance(final_resize_cfg.scale, tuple):
            # Assume square image
            if max(final_resize_cfg.scale) == min(final_resize_cfg.scale):
                self.default_s = max(final_resize_cfg.scale)
            else:
                raise NotImplementedError('non-square scale not considered.')
        sample_frame_cfg = [
            aug for aug in data_cfg.pipeline if aug.type == 'SampleFrames'
        ][0]
        self.default_t = sample_frame_cfg.clip_len

        if multi_grid_cfg.long_cycle:
            self.schedule = self._get_long_cycle_schedule(
                runner, multi_grid_cfg)
        else:
            raise ValueError(
                'Long cycle must be enabled for multigrid training.')
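

# Illustrative usage sketch (not part of the original module): in a training
# script, one would typically register this hook on the runner when the
# config contains a ``multigrid`` section, e.g.
#
#     if cfg.get('multigrid', None) is not None:
#         runner.register_hook(LongShortCycleHook(cfg))
#
# ``register_hook`` is the standard mmcv runner API; exactly where this is
# wired up in a real training pipeline is an assumption here.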