transform.py 11.6 KB
Newer Older
1
2
import math
import torch
3
import torchvision
4
5

from torch import nn, Tensor
6
from typing import List, Tuple, Dict, Optional
7
8
9
10
11

from .image_list import ImageList
from .roi_heads import paste_masks_in_image


12
@torch.jit.unused
13
def _get_shape_onnx(image: Tensor) -> Tensor:
14
    from torch.onnx import operators
15
    return operators.shape_as_tensor(image)[-2:]
16
17


18
@torch.jit.unused
19
def _fake_cast_onnx(v: Tensor) -> float:
20
21
    # ONNX requires a tensor but here we fake its type for JIT.
    return v
22
23


24
def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: float,
                            target: Optional[Dict[str, Tensor]] = None,
                            fixed_size: Optional[Tuple[int, int]] = None,
                            ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    """
    Resize one image and, if present, ``target["masks"]``.

    Unless ``fixed_size`` is given, one scale factor is chosen so the shorter spatial side is
    scaled toward ``self_min_size`` while the longer side does not exceed ``self_max_size``.

    Args:
        image: image tensor without a batch dim (``image[None]`` adds one for
            ``interpolate``). Presumably ``[C, H, W]`` as validated by the caller.
        self_min_size: target length of the shorter spatial side.
        self_max_size: upper bound on the longer spatial side.
        target: optional annotation dict. Only ``"masks"`` is touched here. It is resized
            with the same size / scale factor using ``interpolate``'s default mode, cast back
            with ``.byte()``, and written back into ``target``.
        fixed_size: optional explicit output size given as ``(width, height)``. It is swapped
            into the ``(height, width)`` order that ``interpolate`` expects. When set, it
            overrides the min/max-size scaling.

    Returns:
        The resized image and ``target`` (the same dict object, possibly updated).
    """
    if torchvision._is_tracing():
        # Keep the shape as a tensor while tracing (presumably so it is recorded in the graph).
        im_shape = _get_shape_onnx(image)
    else:
        im_shape = torch.tensor(image.shape[-2:])

    # Exactly one of `size` / `scale_factor` ends up non-None; both are passed to interpolate.
    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        size = [fixed_size[1], fixed_size[0]]
    else:
        min_size = torch.min(im_shape).to(dtype=torch.float32)
        max_size = torch.max(im_shape).to(dtype=torch.float32)
        scale = torch.min(self_min_size / min_size, self_max_size / max_size)

        if torchvision._is_tracing():
            # `scale` stays a tensor under tracing; the helper only fakes the float type.
            scale_factor = _fake_cast_onnx(scale)
        else:
            scale_factor = scale.item()
        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear',
                                            recompute_scale_factor=recompute_scale_factor, align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        # Add a channel dim for interpolate, then drop it again and cast back to uint8.
        mask = torch.nn.functional.interpolate(mask[:, None].float(), size=size, scale_factor=scale_factor,
                                               recompute_scale_factor=recompute_scale_factor)[:, 0].byte()
        target["masks"] = mask
    return image, target


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the targets
    """

    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
        """
        Args:
            min_size: target length of the shorter image side. A list/tuple of candidate
                lengths is also accepted: one is sampled per image in training, and the last
                one is used in eval.
            max_size: upper bound on the longer image side.
            image_mean: per-channel mean subtracted in ``normalize``.
            image_std: per-channel std divided out in ``normalize``.
            size_divisible: batch height/width are rounded up to a multiple of this.
            fixed_size: optional explicit output size forwarded to the resize helper. It
                overrides the min/max-size scaling.
        """
        super(GeneralizedRCNNTransform, self).__init__()
        # Normalize a scalar min_size to a 1-tuple so `resize` can always index / sample it.
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size

    def forward(self,
                images: List[Tensor],
                targets: Optional[List[Dict[str, Tensor]]] = None
                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        """
        Normalize, resize and batch ``images``, and resize ``targets`` to match.

        Args:
            images: list of 3-d ``[C, H, W]`` floating-point tensors (sizes may differ).
            targets: optional per-image annotation dicts. They are shallow-copied first, so
                entries reassigned during resizing do not leak into the caller's dicts.

        Returns:
            An ImageList holding the padded batch plus each image's post-resize (H, W), and
            the resized targets (or None).

        Raises:
            ValueError: if any image is not 3-dimensional.
        """
        # Shallow copy of the list so reassigning images[i] below leaves the caller's list intact.
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Capture per-image sizes before padding; batch_images pads to a common shape.
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        """
        Subtract ``image_mean`` and divide by ``image_std`` per channel (dim 0).

        Raises:
            TypeError: if ``image`` is not a floating-point tensor.
        """
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        # [:, None, None] broadcasts the per-channel stats over H and W.
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0., float(len(k))).item())
        return k[index]

    def resize(self,
               image: Tensor,
               target: Optional[Dict[str, Tensor]] = None,
               ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """
        Resize one image and rescale its target to the new size.

        Masks are resized by the helper. Boxes (required when ``target`` is given) and
        optional keypoints are rescaled from the old (h, w) to the new one.
        """
        h, w = image.shape[-2:]
        if self.training:
            size = float(self.torch_choice(self.min_size))
        else:
            # FIXME assume for now that testing uses the largest scale.
            # Note: this is the *last* entry of min_size, which is the largest only if sorted.
            size = float(self.min_size[-1])
        image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        """
        Tracing-friendly variant of ``batch_images``.

        It pads with ``F.pad`` instead of slice-assignment. It stacks ``img.shape[i]`` values,
        which presumably are tensors only while tracing (TODO confirm), so it is only called
        from ``batch_images`` under tracing.
        """
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            # F.pad takes (left, right) pairs starting from the last dim: W, then H, then C.
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        """
        Return the element-wise maximum of a list of integer lists.

        It assumes every row is as long as the first. The first row is copied, so the
        caller's lists are never mutated.
        """
        maxes = list(the_list[0])
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        """
        Zero-pad ``images`` to a common shape and stack them into one batch tensor.

        Dims 1 and 2 of the common shape are rounded up to a multiple of ``size_divisible``.
        Each image occupies the leading (top-left) corner of its slot.
        """
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(self,
                    result: List[Dict[str, Tensor]],
                    image_shapes: List[Tuple[int, int]],
                    original_image_sizes: List[Tuple[int, int]]
                    ) -> List[Dict[str, Tensor]]:
        """
        Map predictions from the resized-image frame back to the original image sizes.

        This is a no-op in training mode. Otherwise boxes and keypoints are rescaled from
        ``image_shapes`` to ``original_image_sizes``, and masks are passed through
        ``paste_masks_in_image`` with the rescaled boxes. ``result`` is updated in place and
        returned.
        """
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + '('
        _indent = '\n    '
        format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
        format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
                                                                                         self.max_size)
        format_string += '\n)'
        return format_string


def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """
    Rescale keypoint coordinates from ``original_size`` to ``new_size``.

    Args:
        keypoints: tensor whose last dim starts with ``(x, y, ...)``. Index 0 is scaled by
            the width ratio and index 1 by the height ratio. The traced branch indexes
            ``[:, :, 0..2]``, so while tracing it must be 3-d with 3 values per keypoint.
        original_size: ``(height, width)`` the keypoints currently refer to.
        new_size: ``(height, width)`` to map them to.

    Returns:
        A new tensor. ``keypoints`` itself is not modified (it is cloned first).
    """
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    if torch._C._get_tracing_state():
        # Rebuild out-of-place while tracing (presumably because in-place slice writes export poorly).
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """
    Rescale ``(xmin, ymin, xmax, ymax)`` boxes from ``original_size`` to ``new_size``.

    Args:
        boxes: 2-d tensor with 4 columns ordered ``(xmin, ymin, xmax, ymax)`` (unbound on dim 1).
        original_size: ``(height, width)`` the boxes currently refer to.
        new_size: ``(height, width)`` to map them to.

    Returns:
        A new tensor of the same shape. x-coordinates are scaled by the width ratio and
        y-coordinates by the height ratio.
    """
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device) /
        torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)