import random
import math
import torch
from torch import nn, Tensor
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional

from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
from .roi_heads import paste_masks_in_image


@torch.jit.unused
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
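    # ONNX-friendly variant of _resize_image_and_masks: the scale factor is
    # computed with tensor ops (shape_as_tensor, torch.min) so it stays dynamic
    # in the exported graph instead of being baked in at trace time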
    from torch.onnx import operators
    im_shape = operators.shape_as_tensor(image)[-2:]
    min_size = torch.min(im_shape).to(dtype=torch.float32)
    max_size = torch.max(im_shape).to(dtype=torch.float32)
    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)

    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear',
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
        target["masks"] = mask
    return image, target


def _resize_image_and_masks(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
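    # rescale so the shorter side matches self_min_size, unless that would push
    # the longer side past self_max_size, in which case scale to self_max_size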
    im_shape = torch.tensor(image.shape[-2:])
    min_size = float(torch.min(im_shape))
    max_size = float(torch.max(im_shape))
    scale_factor = self_min_size / min_size
    if max_size * scale_factor > self_max_size:
        scale_factor = self_max_size / max_size
    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear',
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
        target["masks"] = mask
    return image, target


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the targets
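
    Example (illustrative values; the detection models pass in their own defaults)::

        >>> transform = GeneralizedRCNNTransform(min_size=800, max_size=1333,
        >>>                                      image_mean=[0.485, 0.456, 0.406],
        >>>                                      image_std=[0.229, 0.224, 0.225])
        >>> images = [torch.rand(3, 480, 640), torch.rand(3, 600, 400)]
        >>> image_list, targets = transform(images)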
    """

    def __init__(self, min_size, max_size, image_mean, image_std):
        super(GeneralizedRCNNTransform, self).__init__()
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self,
                images,       # type: List[Tensor]
                targets=None  # type: Optional[List[Dict[str, Tensor]]]
                ):
        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
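        # work on a shallow copy so the caller's list of images is not modified in place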
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image):
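        # per-channel normalization; mean and std are broadcast over H and W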
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k):
        # type: (List[int]) -> int
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(k))).item())
        return k[index]

    def resize(self, image, target):
        # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
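        # pick the target min_size (sampled randomly during training for
        # multi-scale augmentation) and resize the image together with any
        # masks, boxes and keypoints in the target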
        h, w = image.shape[-2:]
        if self.training:
            size = float(self.torch_choice(self.min_size))
        else:
            # FIXME assume for now that testing uses the largest scale
            size = float(self.min_size[-1])
        if torchvision._is_tracing():
            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
        else:
            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int) -> Tensor
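        # trace-friendly batching: compute the max size with tensor ops and pad
        # each image with F.pad, since the indexed copy_ used in batch_images
        # does not export to ONNX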
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list):
        # type: (List[List[int]]) -> List[int]
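        # element-wise maximum over a list of shape lists, e.g.
        # [[3, 100, 200], [3, 150, 120]] -> [3, 150, 200]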
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images, size_divisible=32):
        # type: (List[Tensor], int) -> Tensor
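        # pad all images to a common size (rounded up to a multiple of
        # size_divisible so strided feature maps stay aligned) and stack them
        # into a single batched tensor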
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
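        # copy each image into the top-left corner of its zero-filled slot;
        # the remaining area acts as padding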
        for img, pad_img in zip(images, batched_imgs):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(self,
                    result,               # type: List[Dict[str, Tensor]]
                    image_shapes,         # type: List[Tuple[int, int]]
                    original_image_sizes  # type: List[Tuple[int, int]]
                    ):
        # type: (...) -> List[Dict[str, Tensor]]
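        # map predictions from the resized image space back to the coordinates
        # of the original input images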
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
                                                                                         self.max_size)
        format_string += '\n)'
        return format_string


def resize_keypoints(keypoints, original_size, new_size):
    # type: (Tensor, List[int], List[int]) -> Tensor
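    # scale the x / y keypoint coordinates by the ratio of new to original image size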
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
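    # when tracing, avoid in-place indexed assignment (which does not export
    # well) and rebuild the tensor with torch.stack instead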
    if torch._C._get_tracing_state():
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes, original_size, new_size):
    # type: (Tensor, List[int], List[int]) -> Tensor
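    # scale box coordinates by the height / width ratio between the new and
    # original image sizes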
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)