import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchvision
from torch import nn, Tensor

from .image_list import ImageList
from .roi_heads import paste_masks_in_image


@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
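    # torch.onnx is imported lazily because it is only needed while tracing;
    # shape_as_tensor keeps the spatial dimensions (H, W) as a tensor rather than
    # Python ints, so they remain dynamic in the exported graph.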
    from torch.onnx import operators

    return operators.shape_as_tensor(image)[-2:]


@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    # ONNX requires a tensor but here we fake its type for JIT.
    return v


def _resize_image_and_masks(
    image: Tensor,
    self_min_size: int,
    self_max_size: int,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
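    # Resize `image` (and the "masks" entry of `target`, if present) so that its
    # shorter side roughly matches `self_min_size` without the longer side exceeding
    # `self_max_size`; if `fixed_size` is given, resize to that exact size instead.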
    if torchvision._is_tracing():
        im_shape = _get_shape_onnx(image)
    elif torch.jit.is_scripting():
        im_shape = torch.tensor(image.shape[-2:])
    else:
        im_shape = image.shape[-2:]

    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        size = [fixed_size[1], fixed_size[0]]
    else:
        if torch.jit.is_scripting() or torchvision._is_tracing():
            min_size = torch.min(im_shape).to(dtype=torch.float32)
            max_size = torch.max(im_shape).to(dtype=torch.float32)
            self_min_size_f = float(self_min_size)
            self_max_size_f = float(self_max_size)
            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

            if torchvision._is_tracing():
                scale_factor = _fake_cast_onnx(scale)
            else:
                scale_factor = scale.item()

        else:
            # Do it the normal way
            min_size = min(im_shape)
            max_size = max(im_shape)
            scale_factor = min(self_min_size / min_size, self_max_size / max_size)
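            # e.g. for self_min_size=800, self_max_size=1333 (illustrative values, not
            # defaults enforced here) and a 480x640 image:
            # scale_factor = min(800 / 480, 1333 / 640) = min(~1.67, ~2.08) ~= 1.67,
            # so the image is resized to roughly 800x1067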

        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(
        image[None],
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        recompute_scale_factor=recompute_scale_factor,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(
            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
        )[:, 0].byte()
        target["masks"] = mask
    return image, target


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
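
    Example (illustrative sketch; the size and normalization values below are
    arbitrary placeholders, not values prescribed by this class)::

        transform = GeneralizedRCNNTransform(
            min_size=800,
            max_size=1333,
            image_mean=[0.485, 0.456, 0.406],  # placeholder ImageNet-style statistics
            image_std=[0.229, 0.224, 0.225],
        )
        images = [torch.rand(3, 480, 640), torch.rand(3, 500, 375)]
        image_list, _ = transform(images)
        # image_list.tensors: a single padded batch tensor
        # image_list.image_sizes: per-image (H, W) after resizing, before padding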
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
        **kwargs: Any,
    ):
        super().__init__()
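        # min_size may be a single value or a sequence of candidate values; resize()
        # picks one of them at random during training (see torch_choice below)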
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size
        self._skip_resize = kwargs.pop("_skip_resize", False)

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            torch._assert(
                len(image_size) == 2,
                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
            )
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
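        # standard per-channel normalization: (image - mean) / std, with mean/std
        # broadcast over the spatial dimensions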
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops, so it can be compiled with
        TorchScript and we use PyTorch's RNG (not native RNG)
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        h, w = image.shape[-2:]
        if self.training:
            if self._skip_resize:
                return image, target
            size = self.torch_choice(self.min_size)
        else:
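            # at inference time, use the last entry of min_size (typically the largest scale)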
            size = self.min_size[-1]
        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
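        # element-wise maximum over a list of shape lists,
        # e.g. [[3, 480, 640], [3, 500, 375]] -> [3, 500, 640]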
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
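        # Pad every image to a common size (the per-dimension maximum, rounded up to a
        # multiple of size_divisible) and stack them into a single [N, C, H, W] tensor.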
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
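        # Map predictions from the resized image space back to the original image sizes:
        # boxes are rescaled, masks are pasted into full-resolution images, and keypoints
        # are rescaled accordingly.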
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string


def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
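    # scale the keypoint (x, y) coordinates by the ratio of new to original image size;
    # the visibility channel (index 2) is left untouched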
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    if torch._C._get_tracing_state():
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
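    # scale (xmin, ymin, xmax, ymax) by the height/width ratios between the
    # new and the original image size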
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)