import math

import torch
from torch import nn
from torchvision.ops import misc as misc_nn_ops

from .image_list import ImageList
from .roi_heads import paste_masks_in_image


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input normalization and resizing before feeding the data to a
    GeneralizedRCNN model, and maps the predictions back to the original
    image sizes in postprocess().
    """

    def __init__(self, min_size, max_size, image_mean, image_std):
        super(GeneralizedRCNNTransform, self).__init__()
        self.min_size = float(min_size)
        self.max_size = float(max_size)
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self, images, targets=None):
        for i in range(len(images)):
            image = images[i]
            target = targets[i] if targets is not None else targets
            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target = self.resize(image, target)
            images[i] = image
            if targets is not None:
                targets[i] = target
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_list = ImageList(images, image_sizes)
        return image_list, targets

    def normalize(self, image):
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def resize(self, image, target):
        h, w = image.shape[-2:]
        min_size = min(image.shape[-2:])
        max_size = max(image.shape[-2:])
        # Scale so the smaller side matches self.min_size, but cap the scale so
        # the larger side never exceeds self.max_size.
        scale_factor = self.min_size / min_size
        if max_size * scale_factor > self.max_size:
            scale_factor = self.max_size / max_size
        image = torch.nn.functional.interpolate(
            image[None], scale_factor=scale_factor, mode='bilinear',
            align_corners=False)[0]

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "masks" in target:
            mask = target["masks"]
            mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
            target["masks"] = mask

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    def batch_images(self, images, size_divisible=32):
        # Pad every image to the per-batch maximum size, rounded up to a
        # multiple of size_divisible, and stack them into a single tensor.
        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))

        stride = size_divisible
        max_size = list(max_size)
        max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
        max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
        max_size = tuple(max_size)

        batch_shape = (len(images),) + max_size
        batched_imgs = images[0].new(*batch_shape).zero_()
        for img, pad_img in zip(images, batched_imgs):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        return batched_imgs

    def postprocess(self, result, image_shapes, original_image_sizes):
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result


def resize_keypoints(keypoints, original_size, new_size):
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    resized_data[..., 0] *= ratio_w
    resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes, original_size, new_size):
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
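
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API): how a
# detection model typically drives this transform. The mean/std values below
# are the standard ImageNet statistics commonly used by torchvision's
# pretrained detection models; the min_size/max_size and input shapes are
# arbitrary example values.
#
#   transform = GeneralizedRCNNTransform(min_size=800, max_size=1333,
#                                        image_mean=[0.485, 0.456, 0.406],
#                                        image_std=[0.229, 0.224, 0.225])
#   images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
#   image_list, _ = transform(images)
#   image_list.tensors.shape   # batch padded to a common H, W divisible by 32
#   image_list.image_sizes     # per-image (H, W) after resizing, before padding
# ---------------------------------------------------------------------------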