"scripts/vscode:/vscode.git/clone" did not exist on "fb9582c4e15ecae392adfd4e1f4fedaf13ce835c"
Commit 111c27b0 authored by Kai Chen

modify mask target computation

parent 830effcd
 import torch
 import numpy as np
+import mmcv
 
-from .segms import polys_to_mask_wrt_box
 
-
-def mask_target(pos_proposals_list,
-                pos_assigned_gt_inds_list,
-                gt_polys_list,
-                img_meta,
-                cfg):
+def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
+                cfg):
     cfg_list = [cfg for _ in range(len(pos_proposals_list))]
     mask_targets = map(mask_target_single, pos_proposals_list,
-                       pos_assigned_gt_inds_list, gt_polys_list, img_meta,
-                       cfg_list)
-    mask_targets = torch.cat(tuple(mask_targets), dim=0)
+                       pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
+    mask_targets = torch.cat(list(mask_targets))
     return mask_targets
 
 
-def mask_target_single(pos_proposals,
-                       pos_assigned_gt_inds,
-                       gt_polys,
-                       img_meta,
-                       cfg):
+def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
     mask_size = cfg.mask_size
     num_pos = pos_proposals.size(0)
-    mask_targets = pos_proposals.new_zeros((num_pos, mask_size, mask_size))
+    mask_targets = []
     if num_pos > 0:
-        pos_proposals = pos_proposals.cpu().numpy()
+        proposals_np = pos_proposals.cpu().numpy()
         pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
-        scale_factor = img_meta['scale_factor']
         for i in range(num_pos):
-            bbox = pos_proposals[i, :] / scale_factor
-            polys = gt_polys[pos_assigned_gt_inds[i]]
-            mask = polys_to_mask_wrt_box(polys, bbox, mask_size)
-            mask = np.array(mask > 0, dtype=np.float32)
-            mask_targets[i, ...] = torch.from_numpy(mask).to(
-                mask_targets.device)
+            gt_mask = gt_masks[pos_assigned_gt_inds[i]]
+            bbox = proposals_np[i, :].astype(np.int32)
+            x1, y1, x2, y2 = bbox
+            w = np.maximum(x2 - x1 + 1, 1)
+            h = np.maximum(y2 - y1 + 1, 1)
+            # mask is uint8 both before and after resizing
+            target = mmcv.imresize(gt_mask[y1:y1 + h, x1:x1 + w],
+                                   (mask_size, mask_size))
+            mask_targets.append(target)
+        mask_targets = torch.from_numpy(np.stack(mask_targets)).float().to(
+            pos_proposals.device)
+    else:
+        mask_targets = pos_proposals.new_zeros((0, mask_size, mask_size))
     return mask_targets
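For context, a minimal sketch of what the new per-proposal target computation does (all shapes and coordinates below are invented for illustration): the assigned ground-truth binary mask is cropped with the proposal box and resized to cfg.mask_size via mmcv.imresize, so no polygon rasterization is needed anymore.

import mmcv
import numpy as np
import torch

# Hypothetical inputs: one 640x480 image with a single gt instance mask
# and one positive proposal (x1, y1, x2, y2). Values are made up.
mask_size = 28
gt_mask = np.zeros((480, 640), dtype=np.uint8)
gt_mask[100:300, 200:400] = 1
proposal = np.array([190., 90., 410., 310.], dtype=np.float32)

x1, y1, x2, y2 = proposal.astype(np.int32)
w = np.maximum(x2 - x1 + 1, 1)
h = np.maximum(y2 - y1 + 1, 1)
# crop the gt mask with the proposal box and resize to the target size;
# the resized crop stays uint8, matching the comment in mask_target_single
target = mmcv.imresize(gt_mask[y1:y1 + h, x1:x1 + w], (mask_size, mask_size))
target = torch.from_numpy(target).float()  # a single (28, 28) mask target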
@@ -5,71 +5,12 @@ import numpy as np
 from pycocotools.coco import COCO
 from torch.utils.data import Dataset
 
-from .transforms import (ImageTransform, BboxTransform, PolyMaskTransform,
+from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          Numpy2Tensor)
 from .utils import to_tensor, show_ann, random_scale
 from .utils import DataContainer as DC
 
 
-def parse_ann_info(ann_info, cat2label, with_mask=True):
-    """Parse bbox and mask annotation.
-
-    Args:
-        ann_info (list[dict]): Annotation info of an image.
-        cat2label (dict): The mapping from category ids to labels.
-        with_mask (bool): Whether to parse mask annotations.
-
-    Returns:
-        tuple: gt_bboxes, gt_labels and gt_mask_info
-    """
-    gt_bboxes = []
-    gt_labels = []
-    gt_bboxes_ignore = []
-    # each mask consists of one or several polys, each poly is a list of float.
-    if with_mask:
-        gt_mask_polys = []
-        gt_poly_lens = []
-    for i, ann in enumerate(ann_info):
-        if ann.get('ignore', False):
-            continue
-        x1, y1, w, h = ann['bbox']
-        if ann['area'] <= 0 or w < 1 or h < 1:
-            continue
-        bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
-        if ann['iscrowd']:
-            gt_bboxes_ignore.append(bbox)
-        else:
-            gt_bboxes.append(bbox)
-            gt_labels.append(cat2label[ann['category_id']])
-        if with_mask:
-            # Note polys are not resized
-            mask_polys = [
-                p for p in ann['segmentation'] if len(p) >= 6
-            ]  # valid polygons have >= 3 points (6 coordinates)
-            poly_lens = [len(p) for p in mask_polys]
-            gt_mask_polys.append(mask_polys)
-            gt_poly_lens.extend(poly_lens)
-    if gt_bboxes:
-        gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
-        gt_labels = np.array(gt_labels, dtype=np.int64)
-    else:
-        gt_bboxes = np.zeros((0, 4), dtype=np.float32)
-        gt_labels = np.array([], dtype=np.int64)
-
-    if gt_bboxes_ignore:
-        gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
-    else:
-        gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
-
-    ann = dict(
-        bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)
-
-    if with_mask:
-        ann['mask_polys'] = gt_mask_polys
-        ann['poly_lens'] = gt_poly_lens
-    return ann
-
-
 class CocoDataset(Dataset):
 
     def __init__(self,
@@ -138,7 +79,7 @@ class CocoDataset(Dataset):
         self.img_transform = ImageTransform(
             size_divisor=self.size_divisor, **self.img_norm_cfg)
         self.bbox_transform = BboxTransform()
-        self.mask_transform = PolyMaskTransform()
+        self.mask_transform = MaskTransform()
         self.numpy2tensor = Numpy2Tensor()
 
     def __len__(self):
@@ -162,6 +103,67 @@ class CocoDataset(Dataset):
         ann_info = self.coco.loadAnns(ann_ids)
         return ann_info
 
+    def _parse_ann_info(self, ann_info, with_mask=True):
+        """Parse bbox and mask annotation.
+
+        Args:
+            ann_info (list[dict]): Annotation info of an image.
+            with_mask (bool): Whether to parse mask annotations.
+
+        Returns:
+            dict: A dict containing the following keys: bboxes, bboxes_ignore,
+                labels, masks, mask_polys, poly_lens.
+        """
+        gt_bboxes = []
+        gt_labels = []
+        gt_bboxes_ignore = []
+        # each mask consists of one or several polys, each poly is a list of float.
+        if with_mask:
+            gt_masks = []
+            gt_mask_polys = []
+            gt_poly_lens = []
+        for i, ann in enumerate(ann_info):
+            if ann.get('ignore', False):
+                continue
+            x1, y1, w, h = ann['bbox']
+            if ann['area'] <= 0 or w < 1 or h < 1:
+                continue
+            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
+            if ann['iscrowd']:
+                gt_bboxes_ignore.append(bbox)
+            else:
+                gt_bboxes.append(bbox)
+                gt_labels.append(self.cat2label[ann['category_id']])
+            if with_mask:
+                gt_masks.append(self.coco.annToMask(ann))
+                mask_polys = [
+                    p for p in ann['segmentation'] if len(p) >= 6
+                ]  # valid polygons have >= 3 points (6 coordinates)
+                poly_lens = [len(p) for p in mask_polys]
+                gt_mask_polys.append(mask_polys)
+                gt_poly_lens.extend(poly_lens)
+        if gt_bboxes:
+            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+            gt_labels = np.array(gt_labels, dtype=np.int64)
+        else:
+            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+            gt_labels = np.array([], dtype=np.int64)
+
+        if gt_bboxes_ignore:
+            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+        else:
+            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+        ann = dict(
+            bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)
+
+        if with_mask:
+            ann['masks'] = gt_masks
+            # poly format is not used in the current implementation
+            ann['mask_polys'] = gt_mask_polys
+            ann['poly_lens'] = gt_poly_lens
+        return ann
+
     def _set_group_flag(self):
         """Set flag according to image aspect ratio.
@@ -200,7 +202,7 @@ class CocoDataset(Dataset):
                 idx = self._rand_another(idx)
                 continue
 
-        ann = parse_ann_info(ann_info, self.cat2label, self.with_mask)
+        ann = self._parse_ann_info(ann_info, self.with_mask)
         gt_bboxes = ann['bboxes']
         gt_labels = ann['labels']
         gt_bboxes_ignore = ann['bboxes_ignore']
@@ -223,10 +225,8 @@ class CocoDataset(Dataset):
                                             scale_factor, flip)
         if self.with_mask:
-            gt_mask_polys, gt_poly_lens, num_polys_per_mask = \
-                self.mask_transform(
-                    ann['mask_polys'], ann['poly_lens'],
-                    img_info['height'], img_info['width'], flip)
+            gt_masks = self.mask_transform(ann['masks'], pad_shape,
+                                           scale_factor, flip)
 
         ori_shape = (img_info['height'], img_info['width'], 3)
         img_meta = dict(
@@ -247,10 +247,7 @@ class CocoDataset(Dataset):
         if self.with_crowd:
             data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
         if self.with_mask:
-            data['gt_masks'] = dict(
-                polys=DC(gt_mask_polys, cpu_only=True),
-                poly_lens=DC(gt_poly_lens, cpu_only=True),
-                polys_per_mask=DC(num_polys_per_mask, cpu_only=True))
+            data['gt_masks'] = DC(gt_masks, cpu_only=True)
         return data
 
     def prepare_test_img(self, idx):
...
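To make the shape of the new annotation dict concrete, here is a hypothetical example of what _parse_ann_info could return for an image with two instances (all numbers are invented): the new 'masks' entry holds one full-resolution uint8 mask per instance decoded by COCO.annToMask, while the polygon fields are kept but currently unused.

import numpy as np

# Invented example of the dict returned by _parse_ann_info (two instances
# in a 480x640 image); real values come from the COCO annotations.
ann = dict(
    bboxes=np.array([[10., 20., 110., 220.],
                     [50., 60., 150., 160.]], dtype=np.float32),
    labels=np.array([3, 7], dtype=np.int64),
    bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
    # one binary mask per instance, same spatial size as the image
    masks=[np.zeros((480, 640), dtype=np.uint8),
           np.zeros((480, 640), dtype=np.uint8)],
    # polygon representation, retained but not used by the new pipeline
    mask_polys=[[[10., 20., 110., 20., 110., 220.]],
                [[50., 60., 150., 60., 150., 160.]]],
    poly_lens=[6, 6])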
@@ -10,7 +10,8 @@ __all__ = [
 
 class ImageTransform(object):
-    """Preprocess an image
+    """Preprocess an image.
+
     1. rescale the image to expected size
     2. normalize the image
     3. flip the image (if needed)
@@ -59,7 +60,8 @@ def bbox_flip(bboxes, img_shape):
 
 class BboxTransform(object):
-    """Preprocess gt bboxes
+    """Preprocess gt bboxes.
+
     1. rescale bboxes according to image size
     2. flip bboxes (if needed)
     3. pad the first dimension to `max_num_gts`
@@ -84,17 +86,12 @@ class BboxTransform(object):
 
 class PolyMaskTransform(object):
+    """Preprocess polygons."""
 
     def __init__(self):
         pass
 
     def __call__(self, gt_mask_polys, gt_poly_lens, img_h, img_w, flip=False):
-        """
-        Args:
-            gt_mask_polys(list): a list of masks, each mask is a list of polys,
-                each poly is a list of numbers
-            gt_poly_lens(list): a list of int, indicating the size of each poly
-        """
         if flip:
             gt_mask_polys = segms.flip_segms(gt_mask_polys, img_h, img_w)
         num_polys_per_mask = np.array(
@@ -108,6 +105,28 @@ class PolyMaskTransform(object):
         return gt_mask_polys, gt_poly_lens, num_polys_per_mask
 
 
+class MaskTransform(object):
+    """Preprocess masks.
+
+    1. resize masks to expected size and stack to a single array
+    2. flip the masks (if needed)
+    3. pad the masks (if needed)
+    """
+
+    def __call__(self, masks, pad_shape, scale_factor, flip=False):
+        masks = [
+            mmcv.imrescale(mask, scale_factor, interpolation='nearest')
+            for mask in masks
+        ]
+        if flip:
+            masks = [mask[:, ::-1] for mask in masks]
+        padded_masks = [
+            mmcv.impad(mask, pad_shape[:2], pad_val=0) for mask in masks
+        ]
+        padded_masks = np.stack(padded_masks, axis=0)
+        return padded_masks
+
+
 class Numpy2Tensor(object):
 
     def __init__(self):
...
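A short usage sketch of the new MaskTransform logic (scale factor, flip flag and pad shape are invented): each full-image mask is rescaled with nearest-neighbour interpolation, optionally flipped, padded to pad_shape and stacked into one array, mirroring the calls in __call__ above.

import mmcv
import numpy as np

# Invented inputs: two 480x640 instance masks, image rescaled by 1.5x and
# padded to a multiple of 32 (720x960 -> 736x960).
masks = [np.zeros((480, 640), dtype=np.uint8) for _ in range(2)]
scale_factor = 1.5
pad_shape = (736, 960, 3)
flip = True

# same three steps as MaskTransform.__call__ above
rescaled = [mmcv.imrescale(m, scale_factor, interpolation='nearest')
            for m in masks]                       # -> (720, 960) each
if flip:
    rescaled = [m[:, ::-1] for m in rescaled]     # horizontal flip
padded = np.stack([mmcv.impad(m, pad_shape[:2], pad_val=0)
                   for m in rescaled])
assert padded.shape == (2, 736, 960)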
@@ -108,8 +108,8 @@ class MaskTestMixin(object):
                 x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
             mask_pred = self.mask_head(mask_feats)
             segm_result = self.mask_head.get_seg_masks(
-                mask_pred, det_bboxes, det_labels, self.test_cfg.rcnn,
-                ori_shape)
+                mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
+                scale_factor, rescale)
         return segm_result
 
     def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
...
@@ -4,7 +4,7 @@ import torch.nn as nn
 from .base import BaseDetector
 from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
 from .. import builder
-from mmdet.core import bbox2roi, bbox2result, split_combined_polys, multi_apply
+from mmdet.core import bbox2roi, bbox2result, multi_apply
 
 
 class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
@@ -124,9 +124,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             losses.update(loss_bbox)
 
         if self.with_mask:
-            gt_polys = split_combined_polys(**gt_masks)
             mask_targets = self.mask_head.get_mask_target(
-                pos_proposals, pos_assigned_gt_inds, gt_polys, img_meta,
+                pos_proposals, pos_assigned_gt_inds, gt_masks,
                 self.train_cfg.rcnn)
             pos_rois = bbox2roi(pos_proposals)
             mask_feats = self.mask_roi_extractor(
...
@@ -87,9 +87,9 @@ class FCNMaskHead(nn.Module):
         return mask_pred
 
     def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
-                        img_meta, rcnn_train_cfg):
+                        rcnn_train_cfg):
         mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
-                                   gt_masks, img_meta, rcnn_train_cfg)
+                                   gt_masks, rcnn_train_cfg)
         return mask_targets
 
     def loss(self, mask_pred, mask_targets, labels):
@@ -99,8 +99,9 @@ class FCNMaskHead(nn.Module):
         return loss
 
     def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
-                      ori_shape):
-        """Get segmentation masks from mask_pred and bboxes
+                      ori_shape, scale_factor, rescale):
+        """Get segmentation masks from mask_pred and bboxes.
+
         Args:
             mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                 For single-scale testing, mask_pred is the direct output of
@@ -111,6 +112,7 @@ class FCNMaskHead(nn.Module):
             img_shape (Tensor): shape (3, )
             rcnn_test_cfg (dict): rcnn testing config
            ori_shape: original image size
+
         Returns:
             list[list]: encoded masks
         """
@@ -119,65 +121,34 @@
             assert isinstance(mask_pred, np.ndarray)
 
         cls_segms = [[] for _ in range(self.num_classes - 1)]
-        mask_size = mask_pred.shape[-1]
         bboxes = det_bboxes.cpu().numpy()[:, :4]
         labels = det_labels.cpu().numpy() + 1
-        img_h = ori_shape[0]
-        img_w = ori_shape[1]
 
-        scale = (mask_size + 2.0) / mask_size
-        bboxes = np.round(self._bbox_scaling(bboxes, scale)).astype(np.int32)
-        padded_mask = np.zeros(
-            (mask_size + 2, mask_size + 2), dtype=np.float32)
+        if rescale:
+            img_h, img_w = ori_shape[:2]
+        else:
+            img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+            img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+            scale_factor = 1.0
 
         for i in range(bboxes.shape[0]):
-            bbox = bboxes[i, :].astype(int)
+            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
             label = labels[i]
-            w = bbox[2] - bbox[0] + 1
-            h = bbox[3] - bbox[1] + 1
-            w = max(w, 1)
-            h = max(h, 1)
+            w = max(bbox[2] - bbox[0] + 1, 1)
+            h = max(bbox[3] - bbox[1] + 1, 1)
 
             if not self.class_agnostic:
-                padded_mask[1:-1, 1:-1] = mask_pred[i, label, :, :]
+                mask_pred_ = mask_pred[i, label, :, :]
             else:
-                padded_mask[1:-1, 1:-1] = mask_pred[i, 0, :, :]
-
-            mask = mmcv.imresize(padded_mask, (w, h))
-            mask = np.array(
-                mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
+                mask_pred_ = mask_pred[i, 0, :, :]
             im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
 
-            x0 = max(bbox[0], 0)
-            x1 = min(bbox[2] + 1, img_w)
-            y0 = max(bbox[1], 0)
-            y1 = min(bbox[3] + 1, img_h)
-
-            im_mask[y0:y1, x0:x1] = mask[(y0 - bbox[1]):(y1 - bbox[1]), (
-                x0 - bbox[0]):(x1 - bbox[0])]
+            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
+            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
+                np.uint8)
+            im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
             rle = mask_util.encode(
                 np.array(im_mask[:, :, np.newaxis], order='F'))[0]
             cls_segms[label - 1].append(rle)
 
         return cls_segms
-
-    def _bbox_scaling(self, bboxes, scale, clip_shape=None):
-        """Scaling bboxes and clip the boundary (optional)
-
-        Args:
-            bboxes (ndarray): shape (..., 4)
-            scale (float): scaling factor
-            clip (None or tuple): (h, w)
-
-        Returns:
-            ndarray: scaled bboxes
-        """
-        if float(scale) == 1.0:
-            scaled_bboxes = bboxes.copy()
-        else:
-            w = bboxes[..., 2] - bboxes[..., 0] + 1
-            h = bboxes[..., 3] - bboxes[..., 1] + 1
-            dw = (w * (scale - 1)) * 0.5
-            dh = (h * (scale - 1)) * 0.5
-            scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
-        if clip_shape is not None:
-            return bbox_clip(scaled_bboxes, clip_shape)
-        else:
-            return scaled_bboxes
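To illustrate the simplified pasting in get_seg_masks, a hedged sketch with invented numbers: the class-specific mask probability map is resized directly to the detection box (no border padding or _bbox_scaling step anymore), thresholded, and written into a full-image uint8 canvas that is then RLE-encoded.

import mmcv
import numpy as np
import pycocotools.mask as mask_util

# Invented example: a 28x28 predicted mask for one detection in a 480x640
# image; the threshold stands in for rcnn_test_cfg.mask_thr_binary.
img_h, img_w = 480, 640
mask_thr_binary = 0.5
mask_pred_ = np.random.rand(28, 28).astype(np.float32)
bbox = np.array([200, 100, 399, 309], dtype=np.int32)  # x1, y1, x2, y2

w = max(bbox[2] - bbox[0] + 1, 1)
h = max(bbox[3] - bbox[1] + 1, 1)
im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
bbox_mask = mmcv.imresize(mask_pred_, (w, h))           # resize to the box
bbox_mask = (bbox_mask > mask_thr_binary).astype(np.uint8)
im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
rle = mask_util.encode(np.array(im_mask[:, :, np.newaxis], order='F'))[0]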
@@ -91,6 +91,7 @@ def main():
     cfg.gpus = args.gpus
     # add mmdet version to checkpoint as meta data
     cfg.checkpoint_config.meta = dict(mmdet_version=__version__)
+    cfg.checkpoint_config.config = cfg.text
 
     logger = get_logger(cfg.log_level)
...