# transform.py
import random
import math
import torch
from torch import nn, Tensor
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional

from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
from .roi_heads import paste_masks_in_image


@torch.jit.unused
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
    from torch.onnx import operators
    im_shape = operators.shape_as_tensor(image)[-2:]
    min_size = torch.min(im_shape).to(dtype=torch.float32)
    max_size = torch.max(im_shape).to(dtype=torch.float32)
    scale_factor = self_min_size / min_size
    if max_size * scale_factor > self_max_size:
        scale_factor = self_max_size / max_size

    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear',
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
        target["masks"] = mask
    return image, target


def _resize_image_and_masks(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
    im_shape = torch.tensor(image.shape[-2:])
    min_size = float(torch.min(im_shape))
    max_size = float(torch.max(im_shape))
    scale_factor = self_min_size / min_size
    if max_size * scale_factor > self_max_size:
        scale_factor = self_max_size / max_size
    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear',
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
        target["masks"] = mask
    return image, target
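
# A worked example of the scaling rule implemented by the two helpers above
# (the numbers are illustrative, not taken from this file): with
# self_min_size=800 and self_max_size=1333, a 480x640 image gets
# scale_factor = 800 / 480 ~= 1.667; since 640 * 1.667 ~= 1067 <= 1333, the cap
# is not hit and the image is resized to roughly 800x1067. For a 480x1500 image
# the cap applies: 1500 * 1.667 > 1333, so scale_factor becomes
# 1333 / 1500 ~= 0.889 and the output is roughly 427x1333.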


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the targets
    """

    def __init__(self, min_size, max_size, image_mean, image_std):
        super(GeneralizedRCNNTransform, self).__init__()
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        images = [img for img in images]  # copy the list so the caller's list of images is not modified
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image):
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
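        # image_mean / image_std are per-channel values; indexing with
        # [:, None, None] reshapes them to [C, 1, 1] so they broadcast over the
        # H and W dimensions.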
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, l):
        # type: (List[int])
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(l))).item())
        return l[index]

    def resize(self, image, target):
        # type: (Tensor, Optional[Dict[str, Tensor]])
        h, w = image.shape[-2:]
        if self.training:
            size = float(self.torch_choice(self.min_size))
        else:
            # FIXME assume for now that testing uses the largest scale
            size = float(self.min_size[-1])
        if torchvision._is_tracing():
            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
        else:
            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int) -> Tensor
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
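            # F.pad takes the pad sizes starting from the last dimension, so this
            # pads width (padding[2]), then height (padding[1]), then channels
            # (padding[0]), adding zeros only at the end of each dimension.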
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list):
        # type: (List[List[int]]) -> List[int]
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int)
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for img, pad_img in zip(images, batched_imgs):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs
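
    # Worked example for batch_images (shapes are illustrative): images of shape
    # [3, 799, 1201] and [3, 768, 1216] give max_size = [3, 799, 1216]; rounding
    # the spatial dims up to a multiple of size_divisible=32 yields [3, 800, 1216],
    # so batched_imgs has shape [2, 3, 800, 1216] and each image is copied into
    # the top-left corner of its zero-filled slot.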

    def postprocess(self, result, image_shapes, original_image_sizes):
        # type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]])
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
                                                                                         self.max_size)
        format_string += '\n)'
        return format_string


def resize_keypoints(keypoints, original_size, new_size):
    # type: (Tensor, List[int], List[int])
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    if torch._C._get_tracing_state():
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes, original_size, new_size):
    # type: (Tensor, List[int], List[int])
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
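

# Worked example for resize_boxes (values are illustrative): a box
# [10., 20., 110., 220.] on a 400x600 (H x W) image resized to 800x1200 is scaled
# by ratio_height = 2.0 and ratio_width = 2.0, giving [20., 40., 220., 440.].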