# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchvision
import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.ops import roi_align

from typing import Optional, List, Dict, Tuple

from model.utils import BoxCoder, Matcher


def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    """Scale every box about its own center by ``scale``.

    Args:
        boxes: 2-D tensor whose columns 0..3 are read as (x0, y0, x1, y1).
        scale: factor applied to each box's width and height.

    Returns:
        A new tensor shaped like ``boxes`` with the scaled coordinates in
        columns 0..3; ``boxes`` itself is not modified.
    """
    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
    y_c = (boxes[:, 3] + boxes[:, 1]) * .5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    """Zero-pad the last two dims of ``mask`` by ``padding`` on every side.

    Args:
        mask: tensor whose last two dims are the mask height and width.
        padding: number of zero pixels added on each border.

    Returns:
        ``(padded_mask, scale)`` where ``scale = (M + 2 * padding) / M`` and
        ``M`` is the size of the last dimension. Only the last dimension is
        used, so this presumably expects square masks.
    """
    M = mask.shape[-1]
    scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    """Resize one mask to its box and paste it into an empty image canvas.

    Args:
        mask: 2-D mask tensor (assumed (M, M) -- TODO confirm with caller).
        box: integer (x0, y0, x1, y1) box with inclusive corner coordinates.
        im_h: height of the output canvas.
        im_w: width of the output canvas.

    Returns:
        An ``(im_h, im_w)`` tensor, zero everywhere except where the box
        overlaps the image. There, the bilinearly resized mask is copied in.
        A box with no overlap yields an all-zero canvas.
    """
    # Corners are inclusive, so [x0, x1] spans x1 - x0 + 1 pixels; never
    # resize to an empty size.
    TO_REMOVE = 1
    w = max(int(box[2] - box[0] + TO_REMOVE), 1)
    h = max(int(box[3] - box[1] + TO_REMOVE), 1)

    # F.interpolate expects a batched [N, C, H, W] input.
    mask = mask.expand((1, 1, -1, -1))
    mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)

    # Intersect the (inclusive) box with the image bounds.
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    # A box lying entirely outside the image has nothing to paste. Without
    # this guard, a negative upper bound would be read by slicing as an
    # offset from the end of the axis. Crossed bounds would likewise make
    # the destination and source slices below differ in shape and raise.
    if x_1 <= x_0 or y_1 <= y_0:
        return im_mask

    im_mask[y_0:y_1, x_0:x_1] = mask[
        (y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0])
    ]
    return im_mask


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    """Paste a batch of masks into full-image canvases.

    Each mask is zero-padded by ``padding``. Its box is enlarged by the
    matching ratio, so the original mask content still maps onto the
    original box. Each mask is then pasted with ``paste_mask_in_image``.

    Args:
        masks: per-detection masks; ``m[0]`` is pasted for each entry, so
            presumably shaped (N, 1, M, M) -- TODO confirm with caller.
        boxes: (N, 4) boxes as (x0, y0, x1, y1); truncated to int64 after
            expansion. Paired with ``masks`` positionally.
        img_shape: ``(im_h, im_w)`` of the output canvases.
        padding: zero border added around each mask before resizing.

    Returns:
        Tensor of shape (N, 1, im_h, im_w); (0, 1, im_h, im_w) if there are
        no masks.
    """
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    res = [
        paste_mask_in_image(m[0], b, im_h, im_w)
        for m, b in zip(masks, boxes)
    ]
    if res:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret