transform.py 11.3 KB
Newer Older
1
import math
2
3
from typing import List, Tuple, Dict, Optional

4
import torch
5
import torchvision
6
from torch import nn, Tensor
7
8
9
10
11

from .image_list import ImageList
from .roi_heads import paste_masks_in_image


12
@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    """
    Return the last two dimensions of ``image``'s shape as a tensor.

    Callers use this instead of ``image.shape`` while tracing, presumably so the
    spatial size is produced by an op (and stays dynamic in the exported graph)
    rather than being baked in as Python ints. Marked ``@torch.jit.unused`` so
    TorchScript does not compile it.
    """
    # Local import: the ONNX operators module is only loaded when this helper runs.
    from torch.onnx import operators as onnx_operators

    full_shape = onnx_operators.shape_as_tensor(image)
    return full_shape[-2:]
17
18


19
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    """
    Identity function whose signature pretends to turn a tensor into a ``float``.

    At runtime ``v`` is returned untouched, still a tensor. The ``float`` return
    annotation exists only so the caller can assign the result to an
    ``Optional[float]`` variable under TorchScript typing.
    """
    # ONNX requires a tensor here, so no real conversion (e.g. ``.item()``) is done.
    passthrough = v
    return passthrough
23
24


25
26
27
28
29
30
31
def _resize_image_and_masks(
    image: Tensor,
    self_min_size: float,
    self_max_size: float,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    """
    Resize ``image`` bilinearly, and ``target["masks"]`` (if present) to match.

    Args:
        image: image tensor; its last two dimensions are treated as height and width.
        self_min_size: desired length of the shorter image side.
        self_max_size: upper bound on the length of the longer image side.
        target: optional annotation dict. Only its ``"masks"`` entry is touched here,
            and the dict is updated in place.
        fixed_size: when given, the image is resized to exactly this size, read as
            ``(width, height)``; ``self_min_size`` / ``self_max_size`` are then ignored.

    Returns:
        The resized image and the (possibly updated) target.
    """
    # Spatial shape as a tensor. While tracing it comes from the ONNX helper
    # instead of the Python-level ``image.shape``.
    if torchvision._is_tracing():
        hw_shape = _get_shape_onnx(image)
    else:
        hw_shape = torch.tensor(image.shape[-2:])

    # Exactly one of out_size / factor ends up set, as F.interpolate requires.
    out_size: Optional[List[int]] = None
    factor: Optional[float] = None
    recompute: Optional[bool] = None
    if fixed_size is None:
        shorter = torch.min(hw_shape).to(dtype=torch.float32)
        longer = torch.max(hw_shape).to(dtype=torch.float32)
        # Scale the short side to self_min_size, reduced further if that would push
        # the long side past self_max_size.
        scale = torch.min(self_min_size / shorter, self_max_size / longer)
        if not torchvision._is_tracing():
            factor = scale.item()
        else:
            # Keep the scale as a tensor for ONNX; the helper only fakes the type.
            factor = _fake_cast_onnx(scale)
        recompute = True
    else:
        # F.interpolate takes (H, W); fixed_size is stored as (W, H).
        out_size = [fixed_size[1], fixed_size[0]]

    # Add and then strip a batch dimension around the interpolation.
    image = torch.nn.functional.interpolate(
        image[None],
        size=out_size,
        scale_factor=factor,
        mode="bilinear",
        recompute_scale_factor=recompute,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        # Masks get a channel dim, are interpolated with the default mode (nearest),
        # then converted back to uint8 via .byte().
        stacked = target["masks"][:, None].float()
        stacked = torch.nn.functional.interpolate(
            stacked, size=out_size, scale_factor=factor, recompute_scale_factor=recompute
        )
        target["masks"] = stacked[:, 0].byte()
    return image, target


74
class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the targets

    Args:
        min_size: target length of the shorter image side. A list/tuple of candidates
            is also accepted: in training mode one is drawn at random per image, in
            eval mode the last entry is used.
        max_size: upper bound on the longer image side after resizing.
        image_mean: per-channel values subtracted during normalization.
        image_std: per-channel values divided by during normalization.
        size_divisible: height and width of the batched tensor are padded up to a
            multiple of this value.
        fixed_size: if given, every image is resized to exactly this size, read as
            (width, height), ignoring min_size / max_size.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
    ):
        super().__init__()
        # Promote a single size to a 1-tuple so `resize` can always index/choose from it.
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        """
        Normalize, resize and batch ``images``, resizing ``targets`` to match.

        Args:
            images: list of 3d tensors of shape [C, H, W]; sizes may differ.
            targets: optional per-image annotation dicts. They are shallow-copied
                first, so the caller's dicts are not modified.

        Returns:
            An ``ImageList`` holding the padded batch and each image's size after
            resizing (before padding), and the resized targets (or ``None``).

        Raises:
            ValueError: if an image is not 3-dimensional.
            TypeError: if an image is not floating point (see ``normalize``).
        """
        # Shallow copy of the list so the caller's list is not modified.
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Record sizes before padding; these are the valid regions inside the batch.
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        """
        Subtract ``image_mean`` and divide by ``image_std`` per channel (dim 0).

        Raises:
            TypeError: if ``image`` is not a floating point tensor.
        """
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        # Broadcast the per-channel statistics over H and W.
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops so it can be compiled with
        TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
        is fixed.
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """
        Resize one image and rescale its target's boxes (and keypoints) to match.

        Raises:
            KeyError: if ``target`` is given but has no ``"boxes"`` entry.
        """
        h, w = image.shape[-2:]
        if self.training:
            size = float(self.torch_choice(self.min_size))
        else:
            # FIXME assume for now that testing uses the largest scale
            # (i.e. the last entry of min_size)
            size = float(self.min_size[-1])
        image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            # F.pad lists padding from the last dim backwards: (W, H, C); pad only at the end.
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        """Return the element-wise maximum over equal-length integer lists.

        The input lists are not modified.
        """
        # Copy the first row: writing into the_list[0] directly would mutate the caller's data.
        maxes = list(the_list[0])
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        """
        Zero-pad [C, H, W] images at the bottom/right and stack them into one tensor.

        The batch's H and W are the per-axis maxima rounded up to a multiple of
        ``size_divisible``.
        """
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        """
        Map predictions from resized-image coordinates back to the original image sizes.

        In training mode ``result`` is returned unchanged. Otherwise boxes and
        keypoints are rescaled and masks are pasted into full-size images; the
        dicts in ``result`` are updated in place and ``result`` is returned.
        """
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                # Uses the already-rescaled boxes, i.e. original-image coordinates.
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string

270

271
def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """
    Rescale keypoint coordinates from an image of ``original_size`` to ``new_size``.

    Both sizes are ``(height, width)``. Along the last axis, index 0 of ``keypoints``
    is scaled by the width ratio and index 1 by the height ratio; index 2 is copied
    unchanged. The input tensor is not modified: a rescaled copy is returned.
    """
    dev = keypoints.device
    scale_factors: List[Tensor] = []
    for dim_new, dim_old in zip(new_size, original_size):
        numerator = torch.tensor(dim_new, dtype=torch.float32, device=dev)
        denominator = torch.tensor(dim_old, dtype=torch.float32, device=dev)
        scale_factors.append(numerator / denominator)
    ratio_h, ratio_w = scale_factors

    scaled = keypoints.clone()
    if torch._C._get_tracing_state():
        # While tracing, rebuild the tensor with stack rather than in-place slice
        # updates (presumably because those do not trace/export cleanly).
        xs = scaled[:, :, 0] * ratio_w
        ys = scaled[:, :, 1] * ratio_h
        return torch.stack((xs, ys, scaled[:, :, 2]), dim=2)

    scaled[..., 0] *= ratio_w
    scaled[..., 1] *= ratio_h
    return scaled


289
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """
    Rescale ``(xmin, ymin, xmax, ymax)`` boxes from ``original_size`` to ``new_size``.

    Both sizes are ``(height, width)``. ``boxes`` has one box per row (4 columns);
    x coordinates are multiplied by the width ratio and y coordinates by the height
    ratio. A new tensor is returned.
    """
    dev = boxes.device
    factors: List[Tensor] = []
    for dim_new, dim_old in zip(new_size, original_size):
        factors.append(
            torch.tensor(dim_new, dtype=torch.float32, device=dev)
            / torch.tensor(dim_old, dtype=torch.float32, device=dev)
        )
    ratio_height, ratio_width = factors

    x0, y0, x1, y1 = boxes.unbind(1)
    scaled_columns = (x0 * ratio_width, y0 * ratio_height, x1 * ratio_width, y1 * ratio_height)
    return torch.stack(scaled_columns, dim=1)