# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import numpy as np
import torch

from .builder import AUGMENT
from .utils import one_hot_encoding


class BaseCutMixLayer(object, metaclass=ABCMeta):
    """Base class for CutMixLayer.

    Subclasses implement :meth:`cutmix`. This base class validates the
    arguments and provides the bounding-box / lambda helpers they share.

    Args:
        alpha (float): Parameter of the Beta distribution. It should be
            positive (> 0). Ints are accepted and stored as float.
        num_classes (int): The number of classes.
        prob (float): CutMix probability. It should be in range [0, 1].
            Ints are accepted and stored as float. Defaults to 1.0.
            NOTE(review): no method of this class reads it; presumably the
            caller uses it to decide whether to apply the layer — confirm.
        cutmix_minmax (List[float], optional): Min/max ratio of the bbox side
            to the image side, applied to height and width independently.
            When it is not None, the cutmix bbox is generated from
            ``cutmix_minmax`` instead of ``alpha``. Defaults to None.
        correct_lam (bool): Whether to recompute lambda from the real bbox
            area when the bbox is clipped by image borders. Defaults to True.
    """

    def __init__(self,
                 alpha,
                 num_classes,
                 prob=1.0,
                 cutmix_minmax=None,
                 correct_lam=True):
        super(BaseCutMixLayer, self).__init__()

        # Accept ints as well (e.g. ``alpha=1`` in a config), but store floats
        # so ``self.alpha`` / ``self.prob`` keep the float type they had.
        assert isinstance(alpha, (int, float)) and alpha > 0
        assert isinstance(num_classes, int)
        assert isinstance(prob, (int, float)) and 0.0 <= prob <= 1.0

        self.alpha = float(alpha)
        self.num_classes = num_classes
        self.prob = float(prob)
        self.cutmix_minmax = cutmix_minmax
        self.correct_lam = correct_lam

    def rand_bbox_minmax(self, img_shape, count=None):
        """Min-Max CutMix bounding-box, inspired by the Darknet cutmix
        implementation.

        The bbox height and width are sampled independently from
        ``[int(side * min), int(side * max))`` of the corresponding image
        side. The box is then placed uniformly so that it lies inside the
        image.

        Typical defaults for minmax are usually in the .2-.3 range for min
        and the .8-.9 range for max.

        Args:
            img_shape (tuple): Image shape. Only the last two entries
                (height, width) are used.
            count (int, optional): Number of bboxes to generate. Defaults to
                None, which yields scalars instead of arrays.

        Returns:
            tuple: ``(yl, yu, xl, xu)``, the lower/upper row and column
            bounds of the bbox(es).

        Raises:
            ValueError: Raised by ``np.random.randint`` when a sampling range
                is empty (e.g. ``int(side * min) >= int(side * max)``).
        """
        assert len(self.cutmix_minmax) == 2
        img_h, img_w = img_shape[-2:]
        cut_h = np.random.randint(
            int(img_h * self.cutmix_minmax[0]),
            int(img_h * self.cutmix_minmax[1]),
            size=count)
        cut_w = np.random.randint(
            int(img_w * self.cutmix_minmax[0]),
            int(img_w * self.cutmix_minmax[1]),
            size=count)
        # Upper bound is exclusive, so yl + cut_h (resp. xl + cut_w) never
        # exceeds the image side.
        yl = np.random.randint(0, img_h - cut_h, size=count)
        xl = np.random.randint(0, img_w - cut_w, size=count)
        yu = yl + cut_h
        xu = xl + cut_w
        return yl, yu, xl, xu

    def rand_bbox(self, img_shape, lam, margin=0., count=None):
        """Standard CutMix bounding-box based on the lambda value.

        Each bbox side is ``sqrt(1 - lam)`` times the corresponding image
        side, so the bbox is square only for square images. The bbox center
        is sampled uniformly and the box is clipped to the image. Its real
        area can therefore be smaller than ``1 - lam`` of the image. A
        border margin, given as a percentage of the bbox dimensions, can be
        enforced.

        Args:
            img_shape (tuple): Image shape. Only the last two entries
                (height, width) are used.
            lam (float): Cutmix lambda value.
            margin (float): Percentage of the bbox dimension to enforce as
                margin (reduces the amount of box outside the image).
                Defaults to 0.
            count (int, optional): Number of bboxes to generate. Defaults to
                None, which yields scalars instead of arrays.

        Returns:
            tuple: ``(yl, yh, xl, xh)``, the lower/upper row and column
            bounds of the bbox(es), clipped to the image.
        """
        ratio = np.sqrt(1 - lam)
        img_h, img_w = img_shape[-2:]
        cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
        margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
        cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
        cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
        yl = np.clip(cy - cut_h // 2, 0, img_h)
        yh = np.clip(cy + cut_h // 2, 0, img_h)
        xl = np.clip(cx - cut_w // 2, 0, img_w)
        xh = np.clip(cx + cut_w // 2, 0, img_w)
        return yl, yh, xl, xh

    def cutmix_bbox_and_lam(self, img_shape, lam, count=None):
        """Generate the bbox and optionally correct lambda.

        The bbox comes from :meth:`rand_bbox_minmax` when ``cutmix_minmax`` is
        set, otherwise from :meth:`rand_bbox`. Lambda is recomputed as
        ``1 - bbox_area / image_area`` when ``correct_lam`` is True or
        ``cutmix_minmax`` is set. In the minmax case the incoming ``lam`` was
        never used to size the box.

        Args:
            img_shape (tuple): Image shape. Only the last two entries
                (height, width) are used.
            lam (float): Cutmix lambda value.
            count (int, optional): Number of bboxes to generate. Defaults to
                None.

        Returns:
            tuple: ``((yl, yu, xl, xu), lam)``.
        """
        if self.cutmix_minmax is not None:
            yl, yu, xl, xu = self.rand_bbox_minmax(img_shape, count=count)
        else:
            yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count)
        if self.correct_lam or self.cutmix_minmax is not None:
            bbox_area = (yu - yl) * (xu - xl)
            lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
        return (yl, yu, xl, xu), lam

    @abstractmethod
    def cutmix(self, imgs, gt_label):
        """Mix ``imgs`` and ``gt_label``; implemented by subclasses."""
        pass


@AUGMENT.register_module(name='BatchCutMix')
class BatchCutMixLayer(BaseCutMixLayer):
    r"""CutMix layer for a batch of data.

    CutMix is a method to improve the network's generalization capability. It
    was proposed in `CutMix: Regularization Strategy to Train Strong
    Classifiers with Localizable Features <https://arxiv.org/abs/1905.04899>`_.

    With this method, patches are cut and pasted among training images, and
    the ground truth labels are mixed in proportion to the area of the
    patches.

    Args:
        alpha (float): Parameter of the Beta distribution used to generate
            the mixing ratio. It should be a positive number. More details
            can be found in :class:`BatchMixupLayer`.
        num_classes (int): The number of classes.
        prob (float): The probability to execute cutmix. It should be in
            range [0, 1]. Defaults to 1.0. NOTE(review): :meth:`cutmix` does
            not read it; presumably the caller applies it — confirm.
        cutmix_minmax (List[float], optional): The min/max ratio of the patch
            side to the image side, applied to height and width
            independently. If not None, the bounding-box of patches is
            uniformly sampled within this range, ``alpha`` is ignored for
            the box, and lambda is recomputed from the box area. Otherwise,
            the bounding-box is generated according to ``alpha``.
            Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when the
            cutmix bbox is clipped by image borders. Defaults to True.

    Note:
        If the ``cutmix_minmax`` is None, how to generate the bounding-box of
        patches according to the ``alpha``?

        First, generate a :math:`\lambda`, details can be found in
        :class:`BatchMixupLayer`. Then, the ratio of each side of the
        bounding-box to the corresponding image side is calculated by:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}

        so the area ratio of the (unclipped) bounding-box is
        :math:`1-\lambda`.
    """

    def __init__(self, *args, **kwargs):
        super(BatchCutMixLayer, self).__init__(*args, **kwargs)

    def cutmix(self, img, gt_label):
        """Paste a patch from a shuffled copy of the batch into each image.

        Args:
            img (torch.Tensor): Image batch. Dim 0 is the batch and the last
                two dims are height/width. Assumed to be (N, C, H, W) because
                it is indexed as ``img[:, :, y, x]`` — confirm with callers.
                It is modified in place.
            gt_label (torch.Tensor): Ground-truth labels. They are passed to
                ``one_hot_encoding`` with ``self.num_classes``; presumably
                integer class indices of shape (N,) — confirm.

        Returns:
            tuple: ``(img, mixed_gt_label)``. ``img`` is the same tensor
            object as the input. ``mixed_gt_label`` is
            ``lam * one_hot + (1 - lam) * one_hot[perm]``.
        """
        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = img.size(0)
        # Random pairing: image i receives a patch from image index[i].
        index = torch.randperm(batch_size)

        # lam may be replaced by the area-corrected value here.
        (bby1, bby2, bbx1,
         bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
        # The right-hand side uses advanced (tensor) indexing and therefore
        # yields a copy, so the in-place write does not read its own output.
        img[:, :, bby1:bby2, bbx1:bbx2] = \
            img[index, :, bby1:bby2, bbx1:bbx2]
        mixed_gt_label = lam * one_hot_gt_label + (
            1 - lam) * one_hot_gt_label[index, :]
        return img, mixed_gt_label

    def __call__(self, img, gt_label):
        """Alias for :meth:`cutmix`."""
        return self.cutmix(img, gt_label)