grouped_batch_sampler.py 4.26 KB
Newer Older
VictorSanh's avatar
VictorSanh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)
"""
import bisect
import copy
from collections import defaultdict

Aymeric Augustin's avatar
Aymeric Augustin committed
21
import numpy as np
VictorSanh's avatar
VictorSanh committed
22
23
24
25
from torch.utils.data.sampler import BatchSampler, Sampler

from utils import logger

26

VictorSanh's avatar
VictorSanh committed
27
28
29
30
31
32
def _quantize(x, bins):
    """Map each value of `x` to the index of the bin it falls into.

    Arguments:
        x (iterable): Values to quantize.
        bins (iterable): Bin edges, in any order. The input is left untouched.

    Returns:
        list[int]: One index per value of `x`, in the range [0, len(bins)].
        Index `i` is the number of edges that are less than or equal to the
        value, so a value equal to an edge lands in the bin to its right.
    """
    # sorted() already builds a new list, so no defensive copy of `bins` is needed.
    edges = sorted(bins)
    return [bisect.bisect_right(edges, value) for value in x]

33

VictorSanh's avatar
VictorSanh committed
34
35
36
37
38
39
40
41
42
43
def create_lengths_groups(lengths, k=0):
    """Assign every length to a group id by quantizing it against fixed bins.

    Arguments:
        lengths (iterable): Lengths to bucket.
        k (int): When positive, the bin edges are 3, 7, 11, ... (step 4),
            stopping before `k`. Otherwise a single edge at 10 is used.

    Returns:
        list[int]: The group id of each entry of `lengths`, as produced by
        `_quantize`.
    """
    if k > 0:
        bins = np.arange(start=3, stop=k, step=4).tolist()
    else:
        bins = [10]

    groups = _quantize(lengths, bins)

    # np.unique only reports the groups that actually received samples.
    _, counts = np.unique(groups, return_counts=True)

    # Bin edges padded with the outer bounds, for logging only.
    fbins = [0, *bins, np.inf]
    logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
    logger.info("Count of instances per bin: {}".format(counts))
    return groups

44

VictorSanh's avatar
VictorSanh committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.
    Full batches only contain elements from the same group. Once the base
    sampler is exhausted, the partially filled per-group buffers are
    concatenated in ascending group id order and cut into batches, so these
    trailing batches may mix neighbouring groups and the last one may be
    smaller than `batch_size`.
    It also tries to provide mini-batches which follow an ordering which is
    as close as possible to the ordering from the original sampler.
    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): If the sampler produces indices in range [0, N),
            `group_ids` must be a list of `N` ints which contains the group id of each sample.
            Group ids are used as dictionary keys and sorted, so they must be
            hashable and mutually comparable.
        batch_size (int): Size of mini-batch. Must be positive.
    Raises:
        ValueError: If `sampler` is not a `torch.utils.data.Sampler` or if
            `batch_size` is not positive.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        # Fail early: a non-positive batch size could never fill a batch and
        # would only surface later as an assertion or a division by zero.
        if batch_size <= 0:
            raise ValueError("batch_size should be a positive integer, but got batch_size={}".format(batch_size))
        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

    def __iter__(self):
        """
        Yield lists of indices: first the full single-group batches, in the
        order they fill up, then the batches built from the leftovers.
        """
        # Indices seen so far for each group that do not yet form a full batch.
        buffer_per_group = defaultdict(list)

        num_batches = 0
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer
                num_batches += 1
                # Forget the yielded list; the next index of this group
                # starts a fresh buffer instead of mutating the emitted batch.
                del buffer_per_group[group_id]

        # now we have run out of elements that satisfy
        # the group criteria, let's return the remaining
        # elements so that the size of the sampler is
        # deterministic
        expected_num_batches = len(self)
        num_remaining = expected_num_batches - num_batches
        if num_remaining > 0:
            # for the remaining batches, group the batches by similar lengths:
            # walking the groups in ascending id order keeps neighbours together
            batch_idx = []
            for _, idxs in sorted(buffer_per_group.items(), key=lambda item: item[0]):
                batch_idx.extend(idxs)
                # Each leftover buffer holds fewer than batch_size indices, so
                # at most one batch can complete per group.
                if len(batch_idx) >= self.batch_size:
                    yield batch_idx[: self.batch_size]
                    batch_idx = batch_idx[self.batch_size :]
                    num_remaining -= 1
            if batch_idx:
                yield batch_idx
                num_remaining -= 1
        # Internal invariant: the number of yielded batches matches len(self).
        assert num_remaining == 0

    def __len__(self):
        """
        Return the number of mini-batches rather than the number of samples.
        """
        # Ceiling division: a trailing partial batch still counts as a batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size