short_sampler.py 1.96 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from torch.utils.data.sampler import Sampler


class ShortCycleSampler(Sampler):
    """Extend Sampler to support "short cycle" sampling.

    See paper "A Multigrid Method for Efficiently Training Video Models", Wu et
    al., 2019 (https://arxiv.org/abs/1912.00998) for details.

    Batches are drawn in a repeating cycle of three sizes stored in
    ``self.batch_sizes``: two enlarged short-cycle sizes followed by the
    unmodified ``batch_size``. Each yielded batch is a list of
    ``(index, short_cycle_idx)`` tuples. ``index`` comes from the wrapped
    sampler, and ``short_cycle_idx`` (0, 1 or 2) is the position of that
    batch within the cycle.

    Args:
        sampler (:obj:`torch.Sampler`): The default sampler to be wrapped.
            Must support ``iter()`` and ``len()``.
        batch_size (int): The batch size before short-cycle modification.
        multigrid_cfg (dict): The config for multigrid training. It must
            support attribute access. ``default_s`` (only its first element
            is used) and ``short_cycle_factors`` (only the first two entries
            are used) are read from it.
        crop_size (int): The actual spatial scale.
        drop_last (bool): Whether to drop the last incomplete batch in epoch.
            Default: True.
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 multigrid_cfg,
                 crop_size,
                 drop_last=True):

        self.sampler = sampler
        self.drop_last = drop_last

        base_s = multigrid_cfg.default_s[0]
        # The batch factor is (crop_size / short_cycle_size) ** 2, so that
        # batch_size * spatial_size ** 2 stays roughly constant across the
        # cycle. Clamp to 1: a factor that rounds to 0 would give a batch
        # size of 0, which ``__iter__`` can never fill, so no full batch
        # would ever be emitted.
        bs_factor = [
            max(1, int(round((float(crop_size) / (s * base_s))**2)))
            for s in multigrid_cfg.short_cycle_factors
        ]

        self.batch_sizes = [
            batch_size * bs_factor[0], batch_size * bs_factor[1], batch_size
        ]

    def __iter__(self):
        """Yield lists of ``(index, short_cycle_idx)`` tuples.

        Batch sizes cycle through ``self.batch_sizes``. A trailing batch
        that is not full is yielded only when ``drop_last`` is False.
        """
        counter = 0
        batch_size = self.batch_sizes[0]
        batch = []
        for idx in self.sampler:
            # Tag every sample with its batch's stage in the 3-step cycle so
            # downstream code can tell which short-cycle setting it belongs to.
            batch.append((idx, counter % 3))
            if len(batch) == batch_size:
                yield batch
                counter += 1
                batch_size = self.batch_sizes[counter % 3]
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the exact number of batches ``__iter__`` will yield."""
        # Mirror ``__iter__`` with integer arithmetic. Dividing by the
        # average batch size gives the wrong count whenever the epoch ends
        # mid-cycle. Example: sizes [8, 4, 2] with 10 samples yield one full
        # batch, not two.
        num_samples = len(self.sampler)
        cycle_len = sum(self.batch_sizes)
        num_batches = len(self.batch_sizes) * (num_samples // cycle_len)
        remainder = num_samples % cycle_len
        # Walk the cycle from the start, consuming as many full batches as
        # the leftover samples allow.
        for size in self.batch_sizes:
            if remainder < size:
                break
            num_batches += 1
            remainder -= size
        # Whatever is left forms one partial batch.
        if remainder > 0 and not self.drop_last:
            num_batches += 1
        return num_batches