Commit 57f6da5c authored by bailuo

readme
from .build_loader import build_dataloader
from .sampler import DistributedGroupSampler, GroupSampler
__all__ = ['GroupSampler', 'DistributedGroupSampler', 'build_dataloader']
import platform
from functools import partial
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader
from .sampler import DistributedGroupSampler, DistributedSampler, GroupSampler
if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
def build_dataloader(dataset,
imgs_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
shuffle=True,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
kwargs: any keyword arguments used to initialize the DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
if dist:
rank, world_size = get_dist_info()
# DistributedGroupSampler always shuffles the data so that the images
# on each GPU belong to the same group
if shuffle:
sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
world_size, rank)
else:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=False)
batch_size = imgs_per_gpu
num_workers = workers_per_gpu
else:
sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
pin_memory=False,
**kwargs)
return data_loader
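# Hedged usage sketch (not part of the original module): builds a
# non-distributed loader around a tiny in-memory dataset. The DummyDataset
# below is illustrative only; real mmdet datasets expose the same `flag`
# attribute that GroupSampler uses for aspect-ratio grouping.
def _example_build_dataloader():
    import numpy as np
    import torch
    from torch.utils.data import Dataset

    class DummyDataset(Dataset):
        """Minimal dataset exposing the `flag` array GroupSampler expects."""

        def __init__(self, n=8):
            # flag assigns each sample to a group (e.g. wide vs. tall images)
            self.flag = np.array([i % 2 for i in range(n)], dtype=np.uint8)

        def __len__(self):
            return len(self.flag)

        def __getitem__(self, idx):
            return dict(img=torch.zeros(3, 4, 4), idx=idx)

    loader = build_dataloader(
        DummyDataset(), imgs_per_gpu=2, workers_per_gpu=0, dist=False)
    # each batch is collated with samples_per_gpu=imgs_per_gpu
    return next(iter(loader))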
from __future__ import division
import math
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import DistributedSampler as _DistributedSampler
from torch.utils.data import Sampler
class DistributedSampler(_DistributedSampler):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = torch.arange(len(self.dataset)).tolist()
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
class GroupSampler(Sampler):
def __init__(self, dataset, samples_per_gpu=1):
assert hasattr(dataset, 'flag')
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.flag = dataset.flag.astype(np.int64)
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, size in enumerate(self.group_sizes):
self.num_samples += int(np.ceil(
size / self.samples_per_gpu)) * self.samples_per_gpu
def __iter__(self):
indices = []
for i, size in enumerate(self.group_sizes):
if size == 0:
continue
indice = np.where(self.flag == i)[0]
assert len(indice) == size
np.random.shuffle(indice)
num_extra = int(np.ceil(size / self.samples_per_gpu)
) * self.samples_per_gpu - len(indice)
indice = np.concatenate(
[indice, np.random.choice(indice, num_extra)])
indices.append(indice)
indices = np.concatenate(indices)
indices = [
indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
for i in np.random.permutation(
range(len(indices) // self.samples_per_gpu))
]
indices = np.concatenate(indices)
indices = indices.astype(np.int64).tolist()
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
class DistributedGroupSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a DistributedGroupSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
samples_per_gpu (int, optional): Number of samples per GPU; indices are
grouped into contiguous chunks of this size.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
"""
def __init__(self,
dataset,
samples_per_gpu=1,
num_replicas=None,
rank=None):
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
if rank is None:
rank = _rank
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
assert hasattr(self.dataset, 'flag')
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, j in enumerate(self.group_sizes):
self.num_samples += int(
math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
self.num_replicas)) * self.samples_per_gpu
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = []
for i, size in enumerate(self.group_sizes):
if size > 0:
indice = np.where(self.flag == i)[0]
assert len(indice) == size
indice = indice[list(torch.randperm(int(size),
generator=g))].tolist()
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas)
) * self.samples_per_gpu * self.num_replicas - len(indice)
# pad indice
tmp = indice.copy()
for _ in range(extra // size):
indice.extend(tmp)
indice.extend(tmp[:extra % size])
indices.extend(indice)
assert len(indices) == self.total_size
indices = [
indices[j] for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g))
for j in range(i * self.samples_per_gpu, (i + 1) *
self.samples_per_gpu)
]
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
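# Hedged usage sketch (not part of the original module): the training runner
# is expected to call `set_epoch` before each epoch so that the deterministic
# shuffle differs between epochs. The _Dummy dataset below is illustrative.
def _example_set_epoch():
    class _Dummy(object):
        flag = np.zeros(6, dtype=np.uint8)

        def __len__(self):
            return len(self.flag)

    sampler = DistributedGroupSampler(_Dummy(), samples_per_gpu=2)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseed the per-epoch shuffle
        indices = list(iter(sampler))
        assert len(indices) == len(sampler)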
from .compose import Compose
from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
Transpose, to_tensor)
from .instaboost import InstaBoost
from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
from .test_aug import MultiScaleFlipAug
from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
SegRescale)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
'RandomCrop', 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'PhotoMetricDistortion', 'Albu', 'InstaBoost'
]
import collections
from mmdet.utils import build_from_cfg
from ..registry import PIPELINES
@PIPELINES.register_module
class Compose(object):
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, data):
for t in self.transforms:
data = t(data)
if data is None:
return None
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
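# Hedged usage sketch (not part of the original module): Compose accepts
# either config dicts (resolved through the PIPELINES registry) or plain
# callables; the lambdas below are illustrative callables.
def _example_compose():
    pipeline = Compose([
        lambda results: dict(results, doubled=results['value'] * 2),
        lambda results: dict(results, tripled=results['value'] * 3),
    ])
    out = pipeline(dict(value=1))
    # out == {'value': 1, 'doubled': 2, 'tripled': 3}
    return out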
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..registry import PIPELINES
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError('type {} cannot be converted to tensor.'.format(
type(data)))
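# Hedged sketch (not part of the original module) of the conversions listed
# in the docstring above; the values are illustrative.
def _example_to_tensor():
    assert to_tensor(np.ones((2, 2))).shape == (2, 2)   # ndarray -> Tensor
    assert to_tensor([1, 2, 3]).tolist() == [1, 2, 3]   # Sequence -> Tensor
    assert to_tensor(5).item() == 5                     # int -> LongTensor
    assert to_tensor(0.5).item() == 0.5                 # float -> FloatTensor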
@PIPELINES.register_module
class ToTensor(object):
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
for key in self.keys:
results[key] = to_tensor(results[key])
return results
def __repr__(self):
return self.__class__.__name__ + '(keys={})'.format(self.keys)
@PIPELINES.register_module
class ImageToTensor(object):
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img.transpose(2, 0, 1))
return results
def __repr__(self):
return self.__class__.__name__ + '(keys={})'.format(self.keys)
@PIPELINES.register_module
class Transpose(object):
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
def __repr__(self):
return self.__class__.__name__ + '(keys={}, order={})'.format(
self.keys, self.order)
@PIPELINES.register_module
class ToDataContainer(object):
def __init__(self,
fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
dict(key='gt_labels'))):
self.fields = fields
def __call__(self, results):
for field in self.fields:
field = field.copy()
key = field.pop('key')
results[key] = DC(results[key], **field)
return results
def __repr__(self):
return self.__class__.__name__ + '(fields={})'.format(self.fields)
@PIPELINES.register_module
class DefaultFormatBundle(object):
"""Default formatting bundle.
It simplifies the pipeline of formatting common fields, including "img",
"proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
These fields are formatted as follows.
- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
- proposals: (1)to tensor, (2)to DataContainer
- gt_bboxes: (1)to tensor, (2)to DataContainer
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
- gt_labels: (1)to tensor, (2)to DataContainer
- gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
(3)to DataContainer (stack=True)
"""
def __call__(self, results):
if 'img' in results:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
results['img'] = DC(to_tensor(img), stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
if key not in results:
continue
results[key] = DC(to_tensor(results[key]))
if 'gt_masks' in results:
results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = DC(
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
return results
def __repr__(self):
return self.__class__.__name__
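# Hedged usage sketch (not part of the original module): a minimal `results`
# dict run through the bundle. The dummy image and boxes are illustrative;
# real pipelines fill these keys during loading and augmentation.
def _example_default_format_bundle():
    results = dict(
        img=np.zeros((4, 4, 3), dtype=np.float32),
        gt_bboxes=np.array([[0., 0., 2., 2.]], dtype=np.float32),
        gt_labels=np.array([1], dtype=np.int64))
    results = DefaultFormatBundle()(results)
    # img is now a stacked DataContainer holding a (3, 4, 4) tensor
    assert results['img'].data.shape == (3, 4, 4)
    return results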
@PIPELINES.register_module
class Collect(object):
"""
Collect data from the loader relevant to the specific task.
This is usually the last stage of the data loading pipeline. Typically,
"keys" is set to some subset of "img", "proposals", "gt_bboxes",
"gt_bboxes_ignore", "gt_labels", and/or "gt_masks".
The "img_meta" item is always populated. The contents of the "img_meta"
dictionary depend on "meta_keys". By default this includes:
- "img_shape": shape of the image input to the network as a tuple
(h, w, c). Note that images may be zero padded on the bottom/right
if the batch tensor is larger than this shape.
- "scale_factor": a float indicating the preprocessing scale
- "flip": a boolean indicating if image flip transform was used
- "filename": path to the image file
- "ori_shape": original shape of the image as a tuple (h, w, c)
- "pad_shape": image shape after padding
- "img_norm_cfg": a dict of normalization information:
- mean - per channel mean subtraction
- std - per channel std divisor
- to_rgb - bool indicating if bgr was converted to rgb
"""
def __init__(self,
keys,
meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'img_norm_cfg')):
self.keys = keys
self.meta_keys = meta_keys
def __call__(self, results):
data = {}
img_meta = {}
for key in self.meta_keys:
img_meta[key] = results[key]
data['img_meta'] = DC(img_meta, cpu_only=True)
for key in self.keys:
data[key] = results[key]
return data
def __repr__(self):
return self.__class__.__name__ + '(keys={}, meta_keys={})'.format(
self.keys, self.meta_keys)
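# Hedged usage sketch (not part of the original module): all default meta
# keys are bundled into a cpu-only DataContainer under 'img_meta'. The dummy
# values below are illustrative.
def _example_collect():
    collect = Collect(keys=['img'])
    results = dict(
        img=np.zeros((2, 2, 3)),
        filename='demo.jpg',
        ori_shape=(2, 2, 3),
        img_shape=(2, 2, 3),
        pad_shape=(2, 2, 3),
        scale_factor=1.0,
        flip=False,
        img_norm_cfg=dict(mean=0, std=1, to_rgb=False))
    data = collect(results)
    assert set(data) == {'img', 'img_meta'}
    return data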
import numpy as np
from ..registry import PIPELINES
@PIPELINES.register_module
class InstaBoost(object):
"""
Data augmentation method from the paper "InstaBoost: Boosting Instance
Segmentation via Probability Map Guided Copy-Pasting".
For implementation details, refer to https://github.com/GothicAi/Instaboost.
"""
def __init__(self,
action_candidate=('normal', 'horizontal', 'skip'),
action_prob=(1, 0, 0),
scale=(0.8, 1.2),
dx=15,
dy=15,
theta=(-1, 1),
color_prob=0.5,
hflag=False,
aug_ratio=0.5):
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError(
'Please run "pip install instaboostfast" '
'to install instaboostfast first for instaboost augmentation.')
self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
scale, dx, dy, theta,
color_prob, hflag)
self.aug_ratio = aug_ratio
def _load_anns(self, results):
labels = results['ann_info']['labels']
masks = results['ann_info']['masks']
bboxes = results['ann_info']['bboxes']
n = len(labels)
anns = []
for i in range(n):
label = labels[i]
bbox = bboxes[i]
mask = masks[i]
x1, y1, x2, y2 = bbox
bbox = [x1, y1, x2 - x1 + 1, y2 - y1 + 1]
anns.append({
'category_id': label,
'segmentation': mask,
'bbox': bbox
})
return anns
def _parse_anns(self, results, anns, img):
gt_bboxes = []
gt_labels = []
gt_masks_ann = []
for ann in anns:
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
gt_bboxes.append(bbox)
gt_labels.append(ann['category_id'])
gt_masks_ann.append(ann['segmentation'])
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
results['ann_info']['labels'] = gt_labels
results['ann_info']['bboxes'] = gt_bboxes
results['ann_info']['masks'] = gt_masks_ann
results['img'] = img
return results
def __call__(self, results):
img = results['img']
anns = self._load_anns(results)
if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError('Please run "pip install instaboostfast" '
'to install instaboostfast first.')
anns, img = instaboost.get_new_data(
anns, img, self.cfg, background=None)
results = self._parse_anns(results, anns, img)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += ('(cfg={}, aug_ratio={})').format(self.cfg, self.aug_ratio)
return repr_str
import os.path as osp
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from ..registry import PIPELINES
@PIPELINES.register_module
class LoadImageFromFile(object):
def __init__(self, to_float32=False, color_type='color'):
self.to_float32 = to_float32
self.color_type = color_type
def __call__(self, results):
if results['img_prefix'] is not None:
filename = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
filename = results['img_info']['filename']
img = mmcv.imread(filename, self.color_type)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
return results
def __repr__(self):
return '{} (to_float32={}, color_type={})'.format(
self.__class__.__name__, self.to_float32, self.color_type)
@PIPELINES.register_module
class LoadAnnotations(object):
def __init__(self,
with_bbox=True,
with_label=True,
with_mask=False,
with_seg=False,
poly2mask=True):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
self.with_seg = with_seg
self.poly2mask = poly2mask
def _load_bboxes(self, results):
ann_info = results['ann_info']
results['gt_bboxes'] = ann_info['bboxes']
gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
if gt_bboxes_ignore is not None:
results['gt_bboxes_ignore'] = gt_bboxes_ignore
results['bbox_fields'].append('gt_bboxes_ignore')
results['bbox_fields'].append('gt_bboxes')
return results
def _load_labels(self, results):
results['gt_labels'] = results['ann_info']['labels']
return results
def _poly2mask(self, mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def _load_masks(self, results):
h, w = results['img_info']['height'], results['img_info']['width']
gt_masks = results['ann_info']['masks']
if self.poly2mask:
gt_masks = [self._poly2mask(mask, h, w) for mask in gt_masks]
results['gt_masks'] = gt_masks
results['mask_fields'].append('gt_masks')
return results
def _load_semantic_seg(self, results):
results['gt_semantic_seg'] = mmcv.imread(
osp.join(results['seg_prefix'], results['ann_info']['seg_map']),
flag='unchanged').squeeze()
results['seg_fields'].append('gt_semantic_seg')
return results
def __call__(self, results):
if self.with_bbox:
results = self._load_bboxes(results)
if results is None:
return None
if self.with_label:
results = self._load_labels(results)
if self.with_mask:
results = self._load_masks(results)
if self.with_seg:
results = self._load_semantic_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += ('(with_bbox={}, with_label={}, with_mask={},'
' with_seg={})').format(self.with_bbox, self.with_label,
self.with_mask, self.with_seg)
return repr_str
@PIPELINES.register_module
class LoadProposals(object):
def __init__(self, num_max_proposals=None):
self.num_max_proposals = num_max_proposals
def __call__(self, results):
proposals = results['proposals']
if proposals.shape[1] not in (4, 5):
raise AssertionError(
'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'.format(proposals.shape))
proposals = proposals[:, :4]
if self.num_max_proposals is not None:
proposals = proposals[:self.num_max_proposals]
if len(proposals) == 0:
proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
results['proposals'] = proposals
results['bbox_fields'].append('proposals')
return results
def __repr__(self):
return self.__class__.__name__ + '(num_max_proposals={})'.format(
self.num_max_proposals)
import mmcv
from ..registry import PIPELINES
from .compose import Compose
@PIPELINES.register_module
class MultiScaleFlipAug(object):
def __init__(self, transforms, img_scale, flip=False):
self.transforms = Compose(transforms)
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
self.flip = flip
def __call__(self, results):
aug_data = []
flip_aug = [False, True] if self.flip else [False]
for scale in self.img_scale:
for flip in flip_aug:
_results = results.copy()
_results['scale'] = scale
_results['flip'] = flip
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(transforms={}, img_scale={}, flip={})'.format(
self.transforms, self.img_scale, self.flip)
return repr_str
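# Hedged usage sketch (not part of the original module): with an empty
# transform list, each (scale, flip) combination still produces one entry,
# so 2 scales x 2 flip states give 4 copies per key. Values are illustrative.
def _example_multi_scale_flip_aug():
    aug = MultiScaleFlipAug(
        transforms=[], img_scale=[(64, 64), (128, 128)], flip=True)
    aug_data = aug(dict(img='placeholder'))
    assert len(aug_data['scale']) == 4
    return aug_data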
import inspect
import mmcv
import numpy as np
from numpy import random
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from ..registry import PIPELINES
try:
from imagecorruptions import corrupt
except ImportError:
corrupt = None
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None
@PIPELINES.register_module
class Resize(object):
"""Resize images & bbox & mask.
This transform resizes the input image to some scale. Bboxes and masks are
then resized with the same scale factor. If the input dict contains the key
"scale", then the scale in the input dict is used, otherwise the specified
scale in the init method is used.
`img_scale` can either be a tuple (single-scale) or a list of tuples
(multi-scale). There are 3 multiscale modes:
- `ratio_range` is not None: randomly sample a ratio from the ratio range
and multiply it with the image scale.
- `ratio_range` is None and `multiscale_mode` == "range": randomly sample a
scale from a range.
- `ratio_range` is None and `multiscale_mode` == "value": randomly sample a
scale from multiple scales.
Args:
img_scale (tuple or list[tuple]): Image scales for resizing.
multiscale_mode (str): Either "range" or "value".
ratio_range (tuple[float]): (min_ratio, max_ratio)
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image.
"""
def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True):
if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
if ratio_range is not None:
# mode 1: given a scale and a range of image ratio
assert len(self.img_scale) == 1
else:
# mode 2: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
@staticmethod
def random_select(img_scales):
assert mmcv.is_list_of(img_scales, tuple)
scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx]
return img_scale, scale_idx
@staticmethod
def random_sample(img_scales):
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long),
max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short),
max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale, None
@staticmethod
def random_sample_ratio(img_scale, ratio_range):
assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None
def _random_scale(self, results):
if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError
results['scale'] = scale
results['scale_idx'] = scale_idx
def _resize_img(self, results):
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
results['img'], results['scale'], return_scale=True)
else:
img, w_scale, h_scale = mmcv.imresize(
results['img'], results['scale'], return_scale=True)
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
dtype=np.float32)
results['img'] = img
results['img_shape'] = img.shape
results['pad_shape'] = img.shape # in case that there is no padding
results['scale_factor'] = scale_factor
results['keep_ratio'] = self.keep_ratio
def _resize_bboxes(self, results):
img_shape = results['img_shape']
for key in results.get('bbox_fields', []):
bboxes = results[key] * results['scale_factor']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1)
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1)
results[key] = bboxes
def _resize_masks(self, results):
for key in results.get('mask_fields', []):
if results[key] is None:
continue
if self.keep_ratio:
masks = [
mmcv.imrescale(
mask, results['scale_factor'], interpolation='nearest')
for mask in results[key]
]
else:
mask_size = (results['img_shape'][1], results['img_shape'][0])
masks = [
mmcv.imresize(mask, mask_size, interpolation='nearest')
for mask in results[key]
]
results[key] = np.stack(masks)
def _resize_seg(self, results):
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key], results['scale'], interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results[key], results['scale'], interpolation='nearest')
results['gt_semantic_seg'] = gt_seg
def __call__(self, results):
if 'scale' not in results:
self._random_scale(results)
self._resize_img(results)
self._resize_bboxes(results)
self._resize_masks(results)
self._resize_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += ('(img_scale={}, multiscale_mode={}, ratio_range={}, '
'keep_ratio={})').format(self.img_scale,
self.multiscale_mode,
self.ratio_range,
self.keep_ratio)
return repr_str
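# Hedged usage sketch (not part of the original module) of the ratio_range
# mode: a single base scale is jittered by a ratio sampled from [0.8, 1.2].
# The dummy image is illustrative.
def _example_resize():
    resize = Resize(img_scale=(128, 64), ratio_range=(0.8, 1.2))
    results = dict(img=np.zeros((60, 120, 3), dtype=np.uint8))
    results = resize(results)
    # keep_ratio=True by default, so the image is rescaled, not distorted
    assert results['keep_ratio'] is True
    return results['img_shape'], results['scale_factor']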
@PIPELINES.register_module
class RandomFlip(object):
"""Flip the image & bbox & mask.
If the input dict contains the key "flip", then the flag will be used,
otherwise it will be randomly decided by a ratio specified in the init
method.
Args:
flip_ratio (float, optional): The flipping probability.
"""
def __init__(self, flip_ratio=None, direction='horizontal'):
self.flip_ratio = flip_ratio
self.direction = direction
if flip_ratio is not None:
assert flip_ratio >= 0 and flip_ratio <= 1
assert direction in ['horizontal', 'vertical']
def bbox_flip(self, bboxes, img_shape, direction):
"""Flip bboxes horizontally.
Args:
bboxes(ndarray): shape (..., 4*k)
img_shape(tuple): (height, width)
"""
assert bboxes.shape[-1] % 4 == 0
flipped = bboxes.copy()
if direction == 'horizontal':
w = img_shape[1]
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
elif direction == 'vertical':
h = img_shape[0]
flipped[..., 1::4] = h - bboxes[..., 3::4] - 1
flipped[..., 3::4] = h - bboxes[..., 1::4] - 1
else:
raise ValueError(
'Invalid flipping direction "{}"'.format(direction))
return flipped
def __call__(self, results):
if 'flip' not in results:
flip = True if np.random.rand() < self.flip_ratio else False
results['flip'] = flip
if 'flip_direction' not in results:
results['flip_direction'] = self.direction
if results['flip']:
# flip image
results['img'] = mmcv.imflip(
results['img'], direction=results['flip_direction'])
# flip bboxes
for key in results.get('bbox_fields', []):
results[key] = self.bbox_flip(results[key],
results['img_shape'],
results['flip_direction'])
# flip masks
for key in results.get('mask_fields', []):
results[key] = np.stack([
mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key]
])
# flip segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imflip(
results[key], direction=results['flip_direction'])
return results
def __repr__(self):
return self.__class__.__name__ + '(flip_ratio={})'.format(
self.flip_ratio)
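# Hedged sketch (not part of the original module) of the horizontal flip
# arithmetic: for an image of width w, x-coordinates map to w - x - 1. The
# values below are illustrative.
def _example_bbox_flip():
    flip = RandomFlip(flip_ratio=1.0)
    bboxes = np.array([[0., 0., 9., 9.]])
    flipped = flip.bbox_flip(bboxes, img_shape=(20, 30), direction='horizontal')
    # x1 = 30 - 9 - 1 = 20, x2 = 30 - 0 - 1 = 29; y-coordinates are unchanged
    assert flipped.tolist() == [[20., 0., 29., 9.]]
    return flipped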
@PIPELINES.register_module
class Pad(object):
"""Pad the image & mask.
There are two padding modes: (1) pad to a fixed size and (2) pad to the
minimum size that is divisible by some number.
Args:
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (float, optional): Padding value, 0 by default.
"""
def __init__(self, size=None, size_divisor=None, pad_val=0):
self.size = size
self.size_divisor = size_divisor
self.pad_val = pad_val
# only one of size and size_divisor should be valid
assert size is not None or size_divisor is not None
assert size is None or size_divisor is None
def _pad_img(self, results):
if self.size is not None:
padded_img = mmcv.impad(results['img'], self.size, pad_val=self.pad_val)
elif self.size_divisor is not None:
padded_img = mmcv.impad_to_multiple(
results['img'], self.size_divisor, pad_val=self.pad_val)
results['img'] = padded_img
results['pad_shape'] = padded_img.shape
results['pad_fixed_size'] = self.size
results['pad_size_divisor'] = self.size_divisor
def _pad_masks(self, results):
pad_shape = results['pad_shape'][:2]
for key in results.get('mask_fields', []):
padded_masks = [
mmcv.impad(mask, pad_shape, pad_val=self.pad_val)
for mask in results[key]
]
if padded_masks:
results[key] = np.stack(padded_masks, axis=0)
else:
results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)
def _pad_seg(self, results):
for key in results.get('seg_fields', []):
results[key] = mmcv.impad(results[key], results['pad_shape'][:2])
def __call__(self, results):
self._pad_img(results)
self._pad_masks(results)
self._pad_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(size={}, size_divisor={}, pad_val={})'.format(
self.size, self.size_divisor, self.pad_val)
return repr_str
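# Hedged usage sketch (not part of the original module) of the size_divisor
# mode: a 30x50 image is padded up to the next multiple of 32 on each side.
def _example_pad():
    pad = Pad(size_divisor=32)
    results = pad(dict(img=np.zeros((30, 50, 3), dtype=np.uint8)))
    assert results['pad_shape'][:2] == (32, 64)
    return results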
@PIPELINES.register_module
class Normalize(object):
"""Normalize the image.
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB,
default is true.
"""
def __init__(self, mean, std, to_rgb=True):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def __call__(self, results):
results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
self.to_rgb)
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(mean={}, std={}, to_rgb={})'.format(
self.mean, self.std, self.to_rgb)
return repr_str
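# Hedged usage sketch (not part of the original module): ImageNet-style
# statistics (illustrative values) applied to a dummy BGR image; the
# normalization config is recorded for later collection into 'img_meta'.
def _example_normalize():
    norm = Normalize(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],
        to_rgb=True)
    results = norm(dict(img=np.full((4, 4, 3), 128, dtype=np.float32)))
    assert results['img'].dtype == np.float32
    return results['img_norm_cfg']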
@PIPELINES.register_module
class RandomCrop(object):
"""Random crop the image & bboxes & masks.
Args:
crop_size (tuple): Expected size after cropping, (h, w).
"""
def __init__(self, crop_size):
self.crop_size = crop_size
def __call__(self, results):
img = results['img']
margin_h = max(img.shape[0] - self.crop_size[0], 0)
margin_w = max(img.shape[1] - self.crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
# crop the image
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape
results['img'] = img
results['img_shape'] = img_shape
# crop bboxes accordingly and clip to the image boundary
for key in results.get('bbox_fields', []):
bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
dtype=np.float32)
bboxes = results[key] - bbox_offset
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1)
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1)
results[key] = bboxes
# crop semantic seg
for key in results.get('seg_fields', []):
results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
# filter out the gt bboxes that are completely cropped
if 'gt_bboxes' in results:
gt_bboxes = results['gt_bboxes']
valid_inds = (gt_bboxes[:, 2] > gt_bboxes[:, 0]) & (
gt_bboxes[:, 3] > gt_bboxes[:, 1])
# if no gt bbox remains after cropping, just skip this image
if not np.any(valid_inds):
return None
results['gt_bboxes'] = gt_bboxes[valid_inds, :]
if 'gt_labels' in results:
results['gt_labels'] = results['gt_labels'][valid_inds]
# filter and crop the masks
if 'gt_masks' in results:
valid_gt_masks = []
for i in np.where(valid_inds)[0]:
gt_mask = results['gt_masks'][i][crop_y1:crop_y2,
crop_x1:crop_x2]
valid_gt_masks.append(gt_mask)
results['gt_masks'] = np.stack(valid_gt_masks)
return results
def __repr__(self):
return self.__class__.__name__ + '(crop_size={})'.format(
self.crop_size)
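# Hedged usage sketch (not part of the original module): cropping may drop
# all ground-truth boxes, in which case the transform returns None and the
# dataset is expected to resample. The dummy data below is illustrative.
def _example_random_crop():
    crop = RandomCrop(crop_size=(8, 8))
    results = crop(dict(
        img=np.zeros((16, 16, 3), dtype=np.uint8),
        bbox_fields=['gt_bboxes'],
        gt_bboxes=np.array([[0., 0., 15., 15.]], dtype=np.float32)))
    if results is not None:
        assert results['img'].shape == (8, 8, 3)
    return results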
@PIPELINES.register_module
class SegRescale(object):
"""Rescale semantic segmentation maps.
Args:
scale_factor (float): The scale factor of the final output.
"""
def __init__(self, scale_factor=1):
self.scale_factor = scale_factor
def __call__(self, results):
for key in results.get('seg_fields', []):
if self.scale_factor != 1:
results[key] = mmcv.imrescale(
results[key], self.scale_factor, interpolation='nearest')
return results
def __repr__(self):
return self.__class__.__name__ + '(scale_factor={})'.format(
self.scale_factor)
@PIPELINES.register_module
class PhotoMetricDistortion(object):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def __call__(self, results):
img = results['img']
# random brightness
if random.randint(2):
delta = random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
# contrast and saturation ranges are stored as lower/upper pairs in __init__
repr_str += ('(brightness_delta={}, contrast_range=({}, {}), '
'saturation_range=({}, {}), hue_delta={})').format(
self.brightness_delta, self.contrast_lower, self.contrast_upper,
self.saturation_lower, self.saturation_upper, self.hue_delta)
return repr_str
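# Hedged usage sketch (not part of the original module): the distortion
# mutates a float image in place, so the dummy image is created as float32
# (real configs run this transform before Normalize). Values are illustrative.
def _example_photo_metric_distortion():
    distort = PhotoMetricDistortion()
    img = np.random.randint(64, 192, size=(8, 8, 3)).astype(np.float32)
    results = distort(dict(img=img))
    assert results['img'].shape == (8, 8, 3)
    return results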
@PIPELINES.register_module
class Expand(object):
"""Random expand the image & bboxes.
Randomly place the original image on a canvas of 'ratio' x original image
size filled with mean values. The ratio is in the range of ratio_range.
Args:
mean (tuple): mean value of dataset.
to_rgb (bool): whether the mean values are given in RGB order; if True
they are reversed to match the BGR image.
ratio_range (tuple): range of expand ratio.
prob (float): probability of applying this transformation
"""
def __init__(self,
mean=(0, 0, 0),
to_rgb=True,
ratio_range=(1, 4),
seg_ignore_label=None,
prob=0.5):
self.to_rgb = to_rgb
self.ratio_range = ratio_range
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range
self.seg_ignore_label = seg_ignore_label
self.prob = prob
def __call__(self, results):
if random.uniform(0, 1) > self.prob:
return results
img, boxes = [results[k] for k in ('img', 'gt_bboxes')]
h, w, c = img.shape
ratio = random.uniform(self.min_ratio, self.max_ratio)
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean).astype(img.dtype)
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
expand_img[top:top + h, left:left + w] = img
boxes = boxes + np.tile((left, top), 2).astype(boxes.dtype)
results['img'] = expand_img
results['gt_bboxes'] = boxes
if 'gt_masks' in results:
expand_gt_masks = []
for mask in results['gt_masks']:
expand_mask = np.full((int(h * ratio), int(w * ratio)),
0).astype(mask.dtype)
expand_mask[top:top + h, left:left + w] = mask
expand_gt_masks.append(expand_mask)
results['gt_masks'] = np.stack(expand_gt_masks)
# not tested
if 'gt_semantic_seg' in results:
assert self.seg_ignore_label is not None
gt_seg = results['gt_semantic_seg']
expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
self.seg_ignore_label).astype(gt_seg.dtype)
expand_gt_seg[top:top + h, left:left + w] = gt_seg
results['gt_semantic_seg'] = expand_gt_seg
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(mean={}, to_rgb={}, ratio_range={}, ' \
'seg_ignore_label={})'.format(
self.mean, self.to_rgb, self.ratio_range,
self.seg_ignore_label)
return repr_str
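# Hedged usage sketch (not part of the original module): with prob=1.0 the
# image is always placed on a larger canvas filled with `mean`, and boxes are
# shifted by the same offset. The dummy data below is illustrative.
def _example_expand():
    expand = Expand(mean=(0, 0, 0), ratio_range=(2, 2), prob=1.0)
    results = expand(dict(
        img=np.ones((10, 10, 3), dtype=np.float32),
        gt_bboxes=np.array([[2., 2., 5., 5.]], dtype=np.float32)))
    assert results['img'].shape == (20, 20, 3)
    return results['gt_bboxes']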
@PIPELINES.register_module
class MinIoURandomCrop(object):
"""Random crop the image & bboxes, the cropped patches have minimum IoU
requirement with original image & bboxes, the IoU threshold is randomly
selected from min_ious.
Args:
min_ious (tuple): minimum IoU threshold for all intersections with
bounding boxes
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
where a >= min_crop_size).
"""
def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
# 1: return ori img
self.sample_mode = (1, *min_ious, 0)
# keep min_ious so that __repr__ can report it
self.min_ious = min_ious
self.min_crop_size = min_crop_size
def __call__(self, results):
img, boxes, labels = [
results[k] for k in ('img', 'gt_bboxes', 'gt_labels')
]
h, w, c = img.shape
while True:
mode = random.choice(self.sample_mode)
if mode == 1:
return results
min_iou = mode
for i in range(50):
new_w = random.uniform(self.min_crop_size * w, w)
new_h = random.uniform(self.min_crop_size * h, h)
# h / w in [0.5, 2]
if new_h / new_w < 0.5 or new_h / new_w > 2:
continue
left = random.uniform(w - new_w)
top = random.uniform(h - new_h)
patch = np.array(
(int(left), int(top), int(left + new_w), int(top + new_h)))
overlaps = bbox_overlaps(
patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
if overlaps.min() < min_iou:
continue
# center of boxes should inside the crop img
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) *
(center[:, 0] < patch[2]) * (center[:, 1] < patch[3]))
if not mask.any():
continue
boxes = boxes[mask]
labels = labels[mask]
# adjust boxes
img = img[patch[1]:patch[3], patch[0]:patch[2]]
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
boxes -= np.tile(patch[:2], 2)
results['img'] = img
results['gt_bboxes'] = boxes
results['gt_labels'] = labels
if 'gt_masks' in results:
valid_masks = [
results['gt_masks'][i] for i in range(len(mask))
if mask[i]
]
results['gt_masks'] = np.stack([
gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks
])
# not tested
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = results['gt_semantic_seg'][
patch[1]:patch[3], patch[0]:patch[2]]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(min_ious={}, min_crop_size={})'.format(
self.min_ious, self.min_crop_size)
return repr_str
@PIPELINES.register_module
class Corrupt(object):
def __init__(self, corruption, severity=1):
self.corruption = corruption
self.severity = severity
def __call__(self, results):
if corrupt is None:
raise RuntimeError('imagecorruptions is not installed')
results['img'] = corrupt(
results['img'].astype(np.uint8),
corruption_name=self.corruption,
severity=self.severity)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(corruption={}, severity={})'.format(
self.corruption, self.severity)
return repr_str
@PIPELINES.register_module
class Albu(object):
def __init__(self,
transforms,
bbox_params=None,
keymap=None,
update_pad_shape=False,
skip_img_without_anno=False):
"""
Adds custom transformations from the Albumentations library.
Please visit https://albumentations.readthedocs.io for more information.
Args:
transforms (list): list of Albumentations transform configs
bbox_params (dict): bbox_params for the Albumentations `Compose`
keymap (dict): mapping of {'input key': 'albumentation-style key'}
update_pad_shape (bool): whether to update 'pad_shape' from the shape
of the transformed image
skip_img_without_anno (bool): whether to skip the image
if no annotations are left after augmentation
"""
if Compose is None:
raise RuntimeError('albumentations is not installed')
self.transforms = transforms
self.filter_lost_elements = False
self.update_pad_shape = update_pad_shape
self.skip_img_without_anno = skip_img_without_anno
# A simple workaround to remove masks without boxes
if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
and 'filter_lost_elements' in bbox_params):
self.filter_lost_elements = True
self.origin_label_fields = bbox_params['label_fields']
bbox_params['label_fields'] = ['idx_mapper']
del bbox_params['filter_lost_elements']
self.bbox_params = (
self.albu_builder(bbox_params) if bbox_params else None)
self.aug = Compose([self.albu_builder(t) for t in self.transforms],
bbox_params=self.bbox_params)
if not keymap:
self.keymap_to_albu = {
'img': 'image',
'gt_masks': 'masks',
'gt_bboxes': 'bboxes'
}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
def albu_builder(self, cfg):
"""Import a module from albumentations.
Inherits some of `build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if albumentations is None:
raise RuntimeError('albumentations is not installed')
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
'type must be a str or valid type, but got {}'.format(
type(obj_type)))
if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]
return obj_cls(**args)
@staticmethod
def mapper(d, keymap):
"""
Dictionary mapper.
Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""
updated_dict = {}
for k, v in d.items():
new_k = keymap.get(k, k)
updated_dict[new_k] = v
return updated_dict
def __call__(self, results):
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
if 'bboxes' in results:
# to list of boxes
if isinstance(results['bboxes'], np.ndarray):
results['bboxes'] = [x for x in results['bboxes']]
# add pseudo-field for filtration
if self.filter_lost_elements:
results['idx_mapper'] = np.arange(len(results['bboxes']))
results = self.aug(**results)
if 'bboxes' in results:
if isinstance(results['bboxes'], list):
results['bboxes'] = np.array(
results['bboxes'], dtype=np.float32)
results['bboxes'] = results['bboxes'].reshape(-1, 4)
# filter label_fields
if self.filter_lost_elements:
results['idx_mapper'] = np.arange(len(results['bboxes']))
for label in self.origin_label_fields:
results[label] = np.array(
[results[label][i] for i in results['idx_mapper']])
if 'masks' in results:
results['masks'] = np.array(
[results['masks'][i] for i in results['idx_mapper']])
if (not len(results['idx_mapper'])
and self.skip_img_without_anno):
return None
if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
results['gt_labels'] = results['gt_labels'].astype(np.int64)
# back to the original format
results = self.mapper(results, self.keymap_back)
# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(transforms={})'.format(self.transforms)
return repr_str
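# Hedged usage sketch (not part of the original module): the transform list
# mirrors the nested dict format understood by `albu_builder`. Requires the
# optional `albumentations` package; the parameters are illustrative.
def _example_albu():
    albu_transforms = [
        dict(
            type='RandomBrightnessContrast',
            brightness_limit=0.1,
            contrast_limit=0.1,
            p=0.5),
    ]
    albu = Albu(transforms=albu_transforms)
    results = albu(dict(
        img=np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)))
    return results['img']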
from mmdet.utils import Registry
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
from .registry import DATASETS
from .xml_style import XMLDataset
@DATASETS.register_module
class VOCDataset(XMLDataset):
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
'tvmonitor')
def __init__(self, **kwargs):
super(VOCDataset, self).__init__(**kwargs)
if 'VOC2007' in self.img_prefix:
self.year = 2007
elif 'VOC2012' in self.img_prefix:
self.year = 2012
else:
raise ValueError('Cannot infer dataset year from img_prefix')
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
from .registry import DATASETS
from .xml_style import XMLDataset
@DATASETS.register_module
class WIDERFaceDataset(XMLDataset):
"""
Reader for the WIDER Face dataset in PASCAL VOC format.
Conversion scripts can be found in
https://github.com/sovrasov/wider-face-pascal-voc-annotations
"""
CLASSES = ('face', )
def __init__(self, **kwargs):
super(WIDERFaceDataset, self).__init__(**kwargs)
def load_annotations(self, ann_file):
img_infos = []
img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids:
filename = '{}.jpg'.format(img_id)
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
folder = root.find('folder').text
img_infos.append(
dict(
id=img_id,
filename=osp.join(folder, filename),
width=width,
height=height))
return img_infos
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .custom import CustomDataset
from .registry import DATASETS
@DATASETS.register_module
class XMLDataset(CustomDataset):
def __init__(self, min_size=None, **kwargs):
super(XMLDataset, self).__init__(**kwargs)
self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}
self.min_size = min_size
def load_annotations(self, ann_file):
img_infos = []
img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids:
filename = 'JPEGImages/{}.jpg'.format(img_id)
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
img_infos.append(
dict(id=img_id, filename=filename, width=width, height=height))
return img_infos
def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in root.findall('object'):
name = obj.find('name').text
label = self.cat2label[name]
difficult = int(obj.find('difficult').text)
bnd_box = obj.find('bndbox')
bbox = [
int(bnd_box.find('xmin').text),
int(bnd_box.find('ymin').text),
int(bnd_box.find('xmax').text),
int(bnd_box.find('ymax').text)
]
ignore = False
if self.min_size:
assert not self.test_mode
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
if w < self.min_size or h < self.min_size:
ignore = True
if difficult or ignore:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
labels_ignore = np.array(labels_ignore)
ann = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann
from .anchor_heads import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403
from .builder import (build_backbone, build_detector, build_head, build_loss,
build_neck, build_roi_extractor, build_shared_head)
from .detectors import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS)
from .roi_extractors import * # noqa: F401,F403
from .shared_heads import * # noqa: F401,F403
__all__ = [
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead
from .solo_head import SOLOHead
from .solov2_head import SOLOv2Head
from .solov2_light_head import SOLOv2LightHead
from .decoupled_solo_head import DecoupledSOLOHead
from .decoupled_solo_light_head import DecoupledSOLOLightHead
__all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'SOLOv2LightHead',
'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
]
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox, force_fp32,
multi_apply, multiclass_nms)
from ..builder import build_loss
from ..registry import HEADS
@HEADS.register_module
class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
Args:
num_classes (int): Number of categories including the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
else:
self.cls_out_channels = num_classes
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.fp16_enabled = False
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
self._init_layers()
def _init_layers(self):
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_anchors * self.cls_out_channels, 1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
def init_weights(self):
normal_init(self.conv_cls, std=0.01)
normal_init(self.conv_reg, std=0.01)
def forward_single(self, x):
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
device (torch.device | str): device for returned tensors
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i], device=device)
multi_level_anchors.append(anchors)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w),
device=device)
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=num_total_samples)
return loss_cls, loss_bbox
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
img_metas,
cfg,
rescale=False):
"""
Transform network output for a batch into labeled boxes.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
img_metas (list[dict]): size / scale info for each image
cfg (mmcv.Config): test / postprocessing configuration
rescale (bool): if True, return boxes in original image space
Returns:
            list[tuple[Tensor, Tensor]]: each item in result_list is a 2-tuple.
The first item is an (n, 5) tensor, where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1. The second item is a
(n,) tensor where each item is the class index of the
corresponding box.
Example:
>>> import mmcv
>>> self = AnchorHead(num_classes=9, in_channels=1)
>>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
>>> cfg = mmcv.Config(dict(
>>> score_thr=0.00,
>>> nms=dict(type='nms', iou_thr=1.0),
>>> max_per_img=10))
>>> feat = torch.rand(1, 1, 3, 3)
>>> cls_score, bbox_pred = self.forward_single(feat)
>>> # note the input lists are over different levels, not images
>>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
>>> result_list = self.get_bboxes(cls_scores, bbox_preds,
>>> img_metas, cfg)
>>> det_bboxes, det_labels = result_list[0]
>>> assert len(result_list) == 1
>>> assert det_bboxes.shape[1] == 5
>>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
"""
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
device = cls_scores[0].device
mlvl_anchors = [
self.anchor_generators[i].grid_anchors(
cls_scores[i].size()[-2:],
self.anchor_strides[i],
device=device) for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_score_list,
bbox_pred_list,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
"""
Transform outputs for a single batch item into labeled boxes.
"""
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list,
bbox_pred_list, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, 1:].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
self.target_stds, img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
if self.use_sigmoid_cls:
# Add a dummy background class to the front when using sigmoid
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (PseudoSampler, anchor_inside_flags, bbox2delta,
build_assigner, delta2bbox, force_fp32,
images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, Scale, bias_init_with_prob
from .anchor_head import AnchorHead
def reduce_mean(tensor):
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
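# A minimal sketch of what reduce_mean does (hypothetical values, assuming a
# process group of two ranks): each rank passes its local count and gets back
# the mean, keeping normalization factors consistent across GPUs.
#   rank 0: reduce_mean(torch.tensor(6.)) -> tensor(4.)
#   rank 1: reduce_mean(torch.tensor(2.)) -> tensor(4.)
# Without an initialized process group the tensor is returned unchanged.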
@HEADS.register_module
class ATSSHead(AnchorHead):
"""
Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection
    The ATSS head structure is similar to FCOS; however, ATSS uses anchor boxes
    and assigns labels by Adaptive Training Sample Selection instead of max-IoU.
https://arxiv.org/abs/1912.02424
"""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=1,
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_centerness=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(ATSSHead, self).__init__(
num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
self.loss_centerness = build_loss(loss_centerness)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.atss_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
self.atss_centerness = nn.Conv2d(
self.feat_channels, self.num_anchors * 1, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.0) for _ in self.anchor_strides])
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.atss_cls, std=0.01, bias=bias_cls)
normal_init(self.atss_reg, std=0.01)
normal_init(self.atss_centerness, std=0.01)
def forward(self, feats):
return multi_apply(self.forward_single, feats, self.scales)
def forward_single(self, x, scale):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.atss_cls(cls_feat)
        # following ATSS, we do not apply exp to bbox_pred
bbox_pred = scale(self.atss_reg(reg_feat)).float()
centerness = self.atss_centerness(reg_feat)
return cls_score, bbox_pred, centerness
def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
label_weights, bbox_targets, num_total_samples, cfg):
anchors = anchors.reshape(-1, 4)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
# classification loss
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
pos_inds = torch.nonzero(labels).squeeze(1)
if len(pos_inds) > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_centerness = centerness[pos_inds]
centerness_targets = self.centerness_target(
pos_anchors, pos_bbox_targets)
pos_decode_bbox_pred = delta2bbox(pos_anchors, pos_bbox_pred,
self.target_means,
self.target_stds)
pos_decode_bbox_targets = delta2bbox(pos_anchors, pos_bbox_targets,
self.target_means,
self.target_stds)
# regression loss
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_decode_bbox_targets,
weight=centerness_targets,
avg_factor=1.0)
# centerness loss
loss_centerness = self.loss_centerness(
pos_centerness,
centerness_targets,
avg_factor=num_total_samples)
else:
loss_bbox = loss_cls * 0
loss_centerness = loss_bbox * 0
centerness_targets = torch.tensor(0).cuda()
return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.atss_target(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = reduce_mean(
torch.tensor(num_total_pos).cuda()).item()
num_total_samples = max(num_total_samples, 1.0)
losses_cls, losses_bbox, loss_centerness,\
bbox_avg_factor = multi_apply(
self.loss_single,
anchor_list,
cls_scores,
bbox_preds,
centernesses,
labels_list,
label_weights_list,
bbox_targets_list,
num_total_samples=num_total_samples,
cfg=cfg)
bbox_avg_factor = sum(bbox_avg_factor)
bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_centerness=loss_centerness)
def centerness_target(self, anchors, bbox_targets):
        # only compute centerness targets for positive samples, otherwise
        # there may be NaN
gts = delta2bbox(anchors, bbox_targets, self.target_means,
self.target_stds)
anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
l_ = anchors_cx - gts[:, 0]
t_ = anchors_cy - gts[:, 1]
r_ = gts[:, 2] - anchors_cx
b_ = gts[:, 3] - anchors_cy
left_right = torch.stack([l_, r_], dim=1)
top_bottom = torch.stack([t_, b_], dim=1)
centerness = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
(top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
assert not torch.isnan(centerness).any()
return centerness
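    # Worked example of the centerness target (hypothetical numbers): for an
    # anchor centred at (10, 10) and a decoded gt box (6, 4, 18, 13),
    # l=4, t=6, r=8, b=3, so
    # centerness = sqrt((min(4, 8) / max(4, 8)) * (min(6, 3) / max(6, 3)))
    #            = sqrt(0.5 * 0.5) = 0.5.
    # A box centred exactly on the anchor gives 1.0; off-centre boxes decay
    # towards 0, which is what weights the regression loss above.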
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def get_bboxes(self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
device = cls_scores[0].device
mlvl_anchors = [
self.anchor_generators[i].grid_anchors(
cls_scores[i].size()[-2:],
self.anchor_strides[i],
device=device) for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
centerness_pred_list = [
centernesses[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
centerness_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
centernesses,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
mlvl_centerness = []
for cls_score, bbox_pred, centerness, anchors in zip(
cls_scores, bbox_preds, centernesses, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
max_scores, _ = (scores * centerness[:, None]).max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
centerness = centerness[topk_inds]
bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
self.target_stds, img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_centerness.append(centerness)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
mlvl_centerness = torch.cat(mlvl_centerness)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
def atss_target(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
cfg,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""
        Almost the same as anchor_target, with a small modification:
        here we also need to return the anchors.
"""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
num_level_anchors_list = [num_level_anchors] * num_imgs
# concat all level anchors and flags to a single tensor
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
anchor_list[i] = torch.cat(anchor_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
self.atss_target_single,
anchor_list,
valid_flag_list,
num_level_anchors_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
cfg=cfg,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
anchors_list = images_to_levels(all_anchors, num_level_anchors)
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors)
return (anchors_list, labels_list, label_weights_list,
bbox_targets_list, bbox_weights_list, num_total_pos,
num_total_neg)
def atss_target_single(self,
flat_anchors,
valid_flags,
num_level_anchors,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
cfg,
label_channels=1,
unmap_outputs=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 6
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
num_level_anchors_inside = self.get_num_level_anchors_inside(
num_level_anchors, inside_flags)
bbox_assigner = build_assigner(cfg.assigner)
assign_result = bbox_assigner.assign(anchors, num_level_anchors_inside,
gt_bboxes, gt_bboxes_ignore,
gt_labels)
bbox_sampler = PseudoSampler()
sampling_result = bbox_sampler.sample(assign_result, anchors,
gt_bboxes)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
sampling_result.pos_gt_bboxes,
self.target_means, self.target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
anchors = unmap(anchors, num_total_anchors, inside_flags)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
return (anchors, labels, label_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds)
def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
split_inside_flags = torch.split(inside_flags, num_level_anchors)
num_level_anchors_inside = [
int(flags.sum()) for flags in split_inside_flags
]
return num_level_anchors_inside
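    # Minimal sketch of get_num_level_anchors_inside (hypothetical counts):
    # with num_level_anchors = [4, 2] and
    # inside_flags = tensor([1, 1, 0, 1, 0, 1], dtype=torch.bool),
    # the split is [1, 1, 0, 1] and [0, 1], giving [3, 1]; the ATSS assigner
    # uses these per-level counts to select its top-k candidates per level.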
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
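# Worked example for center_of_mass (hypothetical mask): a single 3x3 bitmask
# with a 2x2 block of ones in its top-left corner,
# >>> m = torch.zeros(1, 3, 3); m[0, :2, :2] = 1
# >>> center_of_mass(m)
# (tensor([0.5000]), tensor([0.5000]))
# i.e. the centroid sits halfway between pixel columns/rows 0 and 1.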
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
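# Sketch of points_nms behaviour (hypothetical heatmap): the 2x2 max-pool with
# stride 1 keeps only scores that equal the maximum of their local window, so
# for a 1x1x2x2 map [[0.9, 0.3], [0.2, 0.8]] only the 0.9 survives and the
# other entries are zeroed; this suppresses duplicate neighbouring grid-cell
# activations before the category scores are thresholded at test time.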
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
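# dice_loss computes 1 - 2 * sum(p * t) / (sum(p^2) + sum(t^2)) per instance,
# with a 0.001 smoothing term added to each squared sum. A hedged numeric
# check: for a prediction and target that agree on a single pixel of value 1,
# a = 1 and b = c = 1.001, so the loss is roughly 1 - 2 / 2.002, i.e. close to
# 0 for a perfect match and equal to 1 when the masks do not overlap at all.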
@HEADS.register_module
class DecoupledSOLOHead(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.4,
num_grids=None,
cate_down_pos=0,
with_deform=False,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None):
super(DecoupledSOLOHead, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.cate_down_pos = cate_down_pos
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.with_deform = with_deform
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.ins_convs_x = nn.ModuleList()
self.ins_convs_y = nn.ModuleList()
self.cate_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels + 1 if i == 0 else self.seg_feat_channels
self.ins_convs_x.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.ins_convs_y.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.dsolo_ins_list_x = nn.ModuleList()
self.dsolo_ins_list_y = nn.ModuleList()
for seg_num_grid in self.seg_num_grids:
self.dsolo_ins_list_x.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid, 3, padding=1))
self.dsolo_ins_list_y.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid, 3, padding=1))
self.dsolo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
def init_weights(self):
for m in self.ins_convs_x:
normal_init(m.conv, std=0.01)
for m in self.ins_convs_y:
normal_init(m.conv, std=0.01)
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
bias_ins = bias_init_with_prob(0.01)
for m in self.dsolo_ins_list_x:
normal_init(m, std=0.01, bias=bias_ins)
for m in self.dsolo_ins_list_y:
normal_init(m, std=0.01, bias=bias_ins)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.dsolo_cate, std=0.01, bias=bias_cate)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
ins_pred_x, ins_pred_y, cate_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return ins_pred_x, ins_pred_y, cate_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_feat = x
cate_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
ins_feat_x = torch.cat([ins_feat, x], 1)
ins_feat_y = torch.cat([ins_feat, y], 1)
for ins_layer_x, ins_layer_y in zip(self.ins_convs_x, self.ins_convs_y):
ins_feat_x = ins_layer_x(ins_feat_x)
ins_feat_y = ins_layer_y(ins_feat_y)
ins_feat_x = F.interpolate(ins_feat_x, scale_factor=2, mode='bilinear')
ins_feat_y = F.interpolate(ins_feat_y, scale_factor=2, mode='bilinear')
ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat_x)
ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat_y)
# cate branch
for i, cate_layer in enumerate(self.cate_convs):
if i == self.cate_down_pos:
seg_num_grid = self.seg_num_grids[idx]
cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
cate_feat = cate_layer(cate_feat)
cate_pred = self.dsolo_cate(cate_feat)
if eval:
ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return ins_pred_x, ins_pred_y, cate_pred
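    # The coordinate channels above are a CoordConv-style trick: as a small
    # illustration (hypothetical width of 5), torch.linspace(-1, 1, 5) gives
    # [-1.0, -0.5, 0.0, 0.5, 1.0], so each spatial position carries its
    # normalised x (x-branch) or y (y-branch) location, letting the instance
    # branches predict position-sensitive masks.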
def loss(self,
ins_preds_x,
ins_preds_y,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds_x]
ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = multi_apply(
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
# ins
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
ins_preds_x_final = [torch.cat([ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...]
for ins_preds_level_img_x, ins_ind_labels_level_img in
zip(ins_preds_level_x, ins_ind_labels_level)], 0)
for ins_preds_level_x, ins_ind_labels_level in
zip(ins_preds_x, zip(*ins_ind_label_list_xy))]
ins_preds_y_final = [torch.cat([ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...]
for ins_preds_level_img_y, ins_ind_labels_level_img in
zip(ins_preds_level_y, ins_ind_labels_level)], 0)
for ins_preds_level_y, ins_ind_labels_level in
zip(ins_preds_y, zip(*ins_ind_label_list_xy))]
num_ins = 0.
# dice loss
loss_ins = []
for input_x, input_y, target in zip(ins_preds_x_final, ins_preds_y_final, ins_labels):
mask_n = input_x.size(0)
if mask_n == 0:
continue
num_ins += mask_n
input = (input_x.sigmoid())*(input_y.sigmoid())
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean() * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
ins_ind_label_list_xy = []
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
ins_label = torch.zeros([num_grid**2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid**2], dtype=torch.bool, device=device)
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
if len(hit_indices) == 0:
ins_label = torch.zeros([1, featmap_size[0], featmap_size[1]], dtype=torch.uint8,
device=device)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label = torch.zeros([1], dtype=torch.bool, device=device)
ins_ind_label_list.append(ins_ind_label)
ins_ind_label_list_xy.append(cate_label.nonzero())
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
                # top, down, left, right of the centre region
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
# squared
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_ind_label[label] = True
ins_label = ins_label[ins_ind_label]
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label = ins_ind_label[ins_ind_label]
ins_ind_label_list.append(ins_ind_label)
ins_ind_label_list_xy.append(cate_label.nonzero())
return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy
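    # Grid-assignment sketch for solo_target_single (hypothetical numbers,
    # ignoring floating-point rounding): with num_grid = 40 and an 800-pixel
    # upsampled width, a mass centre at center_w = 500 falls into column
    # coord_w = int((500 / 800) // (1. / 40)), i.e. roughly
    # floor(0.625 * 40) = 25; the cells inside the sigma-shrunk box, clamped
    # to coord +/- 1, are also marked positive in cate_label / ins_ind_label.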
def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
assert len(seg_preds_x) == len(cate_preds)
num_levels = len(cate_preds)
featmap_size = seg_preds_x[0].size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list_x = [
seg_preds_x[i][img_id].detach() for i in range(num_levels)
]
seg_pred_list_y = [
seg_preds_y[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)
seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds_x,
seg_preds_y,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
        # per-level offsets for mapping flattened category indices back to
        # x / y mask indices.
trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)
n_stage = len(self.seg_num_grids)
trans_diff[:trans_size[0]] *= 0
seg_diff[:trans_size[0]] *= 0
num_grids[:trans_size[0]] *= self.seg_num_grids[0]
strides[:trans_size[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]
# process.
inds = (cate_preds > cfg.score_thr)
cate_scores = cate_preds[inds]
inds = inds.nonzero()
trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
strides = torch.index_select(strides, dim=0, index=inds[:, 0])
y_inds = (inds[:, 0] - trans_diff) // num_grids
x_inds = (inds[:, 0] - trans_diff) % num_grids
y_inds += seg_diff
x_inds += seg_diff
cate_labels = inds[:, 1]
seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
seg_masks = seg_masks_soft > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
keep = sum_masks > strides
seg_masks_soft = seg_masks_soft[keep, ...]
seg_masks = seg_masks[keep, ...]
cate_scores = cate_scores[keep]
sum_masks = sum_masks[keep]
cate_labels = cate_labels[keep]
# maskness
seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_score
if len(cate_scores) == 0:
return None
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks_soft = seg_masks_soft[sort_inds, :, :]
seg_masks = seg_masks[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
sum_masks = sum_masks[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
keep = cate_scores >= cfg.update_thr
seg_masks_soft = seg_masks_soft[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_masks_soft = seg_masks_soft[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_masks_soft,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
@HEADS.register_module
class DecoupledSOLOLightHead(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.4,
num_grids=None,
cate_down_pos=0,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None,
use_dcn_in_tower=False,
type_dcn=None):
super(DecoupledSOLOLightHead, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.cate_down_pos = cate_down_pos
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_dcn_in_tower = use_dcn_in_tower
self.type_dcn = type_dcn
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.ins_convs = nn.ModuleList()
self.cate_convs = nn.ModuleList()
for i in range(self.stacked_convs):
if self.use_dcn_in_tower and i == self.stacked_convs - 1:
cfg_conv = dict(type=self.type_dcn)
else:
cfg_conv = self.conv_cfg
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.ins_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.dsolo_ins_list_x = nn.ModuleList()
self.dsolo_ins_list_y = nn.ModuleList()
for seg_num_grid in self.seg_num_grids:
self.dsolo_ins_list_x.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid, 3, padding=1))
self.dsolo_ins_list_y.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid, 3, padding=1))
self.dsolo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
def init_weights(self):
for m in self.ins_convs:
normal_init(m.conv, std=0.01)
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
bias_ins = bias_init_with_prob(0.01)
for m in self.dsolo_ins_list_x:
normal_init(m, std=0.01, bias=bias_ins)
for m in self.dsolo_ins_list_y:
normal_init(m, std=0.01, bias=bias_ins)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.dsolo_cate, std=0.01, bias=bias_cate)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
ins_pred_x, ins_pred_y, cate_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return ins_pred_x, ins_pred_y, cate_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_feat = x
cate_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_feat = torch.cat([ins_feat, coord_feat], 1)
for ins_layer in self.ins_convs:
ins_feat = ins_layer(ins_feat)
ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat)
ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat)
# cate branch
for i, cate_layer in enumerate(self.cate_convs):
if i == self.cate_down_pos:
seg_num_grid = self.seg_num_grids[idx]
cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
cate_feat = cate_layer(cate_feat)
cate_pred = self.dsolo_cate(cate_feat)
if eval:
ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return ins_pred_x, ins_pred_y, cate_pred
def loss(self,
ins_preds_x,
ins_preds_y,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds_x]
ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy = multi_apply(
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
# ins
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
ins_preds_x_final = [torch.cat([ins_preds_level_img_x[ins_ind_labels_level_img[:, 1], ...]
for ins_preds_level_img_x, ins_ind_labels_level_img in
zip(ins_preds_level_x, ins_ind_labels_level)], 0)
for ins_preds_level_x, ins_ind_labels_level in
zip(ins_preds_x, zip(*ins_ind_label_list_xy))]
ins_preds_y_final = [torch.cat([ins_preds_level_img_y[ins_ind_labels_level_img[:, 0], ...]
for ins_preds_level_img_y, ins_ind_labels_level_img in
zip(ins_preds_level_y, ins_ind_labels_level)], 0)
for ins_preds_level_y, ins_ind_labels_level in
zip(ins_preds_y, zip(*ins_ind_label_list_xy))]
num_ins = 0.
# dice loss
loss_ins = []
for input_x, input_y, target in zip(ins_preds_x_final, ins_preds_y_final, ins_labels):
mask_n = input_x.size(0)
if mask_n == 0:
continue
num_ins += mask_n
input = (input_x.sigmoid())*(input_y.sigmoid())
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean() * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
ins_ind_label_list_xy = []
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
ins_label = torch.zeros([num_grid**2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid**2], dtype=torch.bool, device=device)
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
if len(hit_indices) == 0:
ins_label = torch.zeros([1, featmap_size[0], featmap_size[1]], dtype=torch.uint8,
device=device)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label = torch.zeros([1], dtype=torch.bool, device=device)
ins_ind_label_list.append(ins_ind_label)
ins_ind_label_list_xy.append(cate_label.nonzero())
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
                # top, down, left, right of the centre region
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
# squared
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_ind_label[label] = True
ins_label = ins_label[ins_ind_label]
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label = ins_ind_label[ins_ind_label]
ins_ind_label_list.append(ins_ind_label)
ins_ind_label_list_xy.append(cate_label.nonzero())
return ins_label_list, cate_label_list, ins_ind_label_list, ins_ind_label_list_xy
def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
assert len(seg_preds_x) == len(cate_preds)
num_levels = len(cate_preds)
featmap_size = seg_preds_x[0].size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list_x = [
seg_preds_x[i][img_id].detach() for i in range(num_levels)
]
seg_pred_list_y = [
seg_preds_y[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)
seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds_x,
seg_preds_y,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
        # per-level offsets for mapping flattened category indices back to
        # x / y mask indices.
trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)
n_stage = len(self.seg_num_grids)
trans_diff[:trans_size[0]] *= 0
seg_diff[:trans_size[0]] *= 0
num_grids[:trans_size[0]] *= self.seg_num_grids[0]
strides[:trans_size[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]
# process.
inds = (cate_preds > cfg.score_thr)
cate_scores = cate_preds[inds]
inds = inds.nonzero()
trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
strides = torch.index_select(strides, dim=0, index=inds[:, 0])
y_inds = (inds[:, 0] - trans_diff) // num_grids
x_inds = (inds[:, 0] - trans_diff) % num_grids
y_inds += seg_diff
x_inds += seg_diff
cate_labels = inds[:, 1]
seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
seg_masks = seg_masks_soft > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
keep = sum_masks > strides
seg_masks_soft = seg_masks_soft[keep, ...]
seg_masks = seg_masks[keep, ...]
cate_scores = cate_scores[keep]
sum_masks = sum_masks[keep]
cate_labels = cate_labels[keep]
# maskness
seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_score
if len(cate_scores) == 0:
return None
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks_soft = seg_masks_soft[sort_inds, :, :]
seg_masks = seg_masks[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
sum_masks = sum_masks[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
keep = cate_scores >= cfg.update_thr
seg_masks_soft = seg_masks_soft[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_masks_soft = seg_masks_soft[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_masks_soft,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores