Unverified Commit 6efefa27 authored by Kai Chen, committed by GitHub

Merge pull request #20 from open-mmlab/dev

Initial public release
parents 2cf13281 54b54d88
from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann
__all__ = [
'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
]
import os.path as osp
import mmcv
import numpy as np
from mmcv.parallel import DataContainer as DC
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from .transforms import (ImageTransform, BboxTransform, MaskTransform,
Numpy2Tensor)
from .utils import to_tensor, show_ann, random_scale
class CocoDataset(Dataset):
def __init__(self,
ann_file,
img_prefix,
img_scale,
img_norm_cfg,
size_divisor=None,
proposal_file=None,
num_max_proposals=1000,
flip_ratio=0,
with_mask=True,
with_crowd=True,
with_label=True,
test_mode=False,
debug=False):
# path of the data file
self.coco = COCO(ann_file)
# filter images with no annotation during training
if not test_mode:
self.img_ids, self.img_infos = self._filter_imgs()
else:
self.img_ids = self.coco.getImgIds()
self.img_infos = [
self.coco.loadImgs(idx)[0] for idx in self.img_ids
]
assert len(self.img_ids) == len(self.img_infos)
# get the mapping from original category ids to labels
self.cat_ids = self.coco.getCatIds()
self.cat2label = {
cat_id: i + 1
for i, cat_id in enumerate(self.cat_ids)
}
        # prefix of image paths
self.img_prefix = img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self.img_scales = img_scale if isinstance(img_scale,
list) else [img_scale]
assert mmcv.is_list_of(self.img_scales, tuple)
# color channel order and normalize configs
self.img_norm_cfg = img_norm_cfg
# proposals
# TODO: revise _filter_imgs to be more flexible
if proposal_file is not None:
self.proposals = mmcv.load(proposal_file)
ori_ids = self.coco.getImgIds()
sorted_idx = [ori_ids.index(id) for id in self.img_ids]
self.proposals = [self.proposals[idx] for idx in sorted_idx]
else:
self.proposals = None
self.num_max_proposals = num_max_proposals
# flip ratio
self.flip_ratio = flip_ratio
assert flip_ratio >= 0 and flip_ratio <= 1
# padding border to ensure the image size can be divided by
# size_divisor (used for FPN)
self.size_divisor = size_divisor
# with crowd or not, False when using RetinaNet
self.with_crowd = with_crowd
# with mask or not
self.with_mask = with_mask
# with label is False for RPN
self.with_label = with_label
# in test mode or not
self.test_mode = test_mode
# debug mode or not
self.debug = debug
# set group flag for the sampler
self._set_group_flag()
# transforms
self.img_transform = ImageTransform(
size_divisor=self.size_divisor, **self.img_norm_cfg)
self.bbox_transform = BboxTransform()
self.mask_transform = MaskTransform()
self.numpy2tensor = Numpy2Tensor()
def __len__(self):
return len(self.img_ids)
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
img_ids = list(set([_['image_id'] for _ in self.coco.anns.values()]))
valid_ids = []
img_infos = []
for i in img_ids:
info = self.coco.loadImgs(i)[0]
if min(info['width'], info['height']) >= min_size:
valid_ids.append(i)
img_infos.append(info)
return valid_ids, img_infos
def _load_ann_info(self, idx):
img_id = self.img_ids[idx]
ann_ids = self.coco.getAnnIds(imgIds=img_id)
ann_info = self.coco.loadAnns(ann_ids)
return ann_info
def _parse_ann_info(self, ann_info, with_mask=True):
"""Parse bbox and mask annotation.
Args:
ann_info (list[dict]): Annotation info of an image.
with_mask (bool): Whether to parse mask annotations.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, mask_polys, poly_lens.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size as the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        #    list of floats.
if with_mask:
gt_masks = []
gt_mask_polys = []
gt_poly_lens = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
if ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
if ann['iscrowd']:
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
if with_mask:
gt_masks.append(self.coco.annToMask(ann))
mask_polys = [
p for p in ann['segmentation'] if len(p) >= 6
] # valid polygons have >= 3 points (6 coordinates)
poly_lens = [len(p) for p in mask_polys]
gt_mask_polys.append(mask_polys)
gt_poly_lens.extend(poly_lens)
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
ann = dict(
bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)
if with_mask:
ann['masks'] = gt_masks
# poly format is not used in the current implementation
ann['mask_polys'] = gt_mask_polys
ann['poly_lens'] = gt_poly_lens
return ann
def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
self.flag = np.zeros(len(self.img_ids), dtype=np.uint8)
for i in range(len(self.img_ids)):
img_info = self.img_infos[i]
if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1
def _rand_another(self, idx):
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_img(idx)
while True:
img_info = self.img_infos[idx]
ann_info = self._load_ann_info(idx)
# load image
img = mmcv.imread(osp.join(self.img_prefix, img_info['file_name']))
if self.debug:
show_ann(self.coco, img, ann_info)
# load proposals if necessary
if self.proposals is not None:
proposals = self.proposals[idx][:self.num_max_proposals, :4]
                # TODO: Handle empty proposals properly. Currently images with
                # no proposals are simply skipped, but in principle they could
                # still be used for training.
if len(proposals) == 0:
idx = self._rand_another(idx)
continue
ann = self._parse_ann_info(ann_info, self.with_mask)
gt_bboxes = ann['bboxes']
gt_labels = ann['labels']
gt_bboxes_ignore = ann['bboxes_ignore']
# skip the image if there is no valid gt bbox
if len(gt_bboxes) == 0:
idx = self._rand_another(idx)
continue
# apply transforms
flip = True if np.random.rand() < self.flip_ratio else False
img_scale = random_scale(self.img_scales) # sample a scale
img, img_shape, pad_shape, scale_factor = self.img_transform(
img, img_scale, flip)
if self.proposals is not None:
proposals = self.bbox_transform(proposals, img_shape,
scale_factor, flip)
gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
flip)
gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
scale_factor, flip)
if self.with_mask:
gt_masks = self.mask_transform(ann['masks'], pad_shape,
scale_factor, flip)
ori_shape = (img_info['height'], img_info['width'], 3)
img_meta = dict(
ori_shape=ori_shape,
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
flip=flip)
data = dict(
img=DC(to_tensor(img), stack=True),
img_meta=DC(img_meta, cpu_only=True),
gt_bboxes=DC(to_tensor(gt_bboxes)))
if self.proposals is not None:
data['proposals'] = DC(to_tensor(proposals))
if self.with_label:
data['gt_labels'] = DC(to_tensor(gt_labels))
if self.with_crowd:
data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
if self.with_mask:
data['gt_masks'] = DC(gt_masks, cpu_only=True)
return data
def prepare_test_img(self, idx):
"""Prepare an image for testing (multi-scale and flipping)"""
img_info = self.img_infos[idx]
img = mmcv.imread(osp.join(self.img_prefix, img_info['file_name']))
proposal = (self.proposals[idx][:, :4]
if self.proposals is not None else None)
def prepare_single(img, scale, flip, proposal=None):
_img, img_shape, pad_shape, scale_factor = self.img_transform(
img, scale, flip)
_img = to_tensor(_img)
_img_meta = dict(
ori_shape=(img_info['height'], img_info['width'], 3),
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
flip=flip)
if proposal is not None:
_proposal = self.bbox_transform(proposal, img_shape,
scale_factor, flip)
_proposal = to_tensor(_proposal)
else:
_proposal = None
return _img, _img_meta, _proposal
imgs = []
img_metas = []
proposals = []
for scale in self.img_scales:
_img, _img_meta, _proposal = prepare_single(
img, scale, False, proposal)
imgs.append(_img)
img_metas.append(DC(_img_meta, cpu_only=True))
proposals.append(_proposal)
if self.flip_ratio > 0:
_img, _img_meta, _proposal = prepare_single(
img, scale, True, proposal)
imgs.append(_img)
img_metas.append(DC(_img_meta, cpu_only=True))
proposals.append(_proposal)
data = dict(img=imgs, img_meta=img_metas)
if self.proposals is not None:
data['proposals'] = proposals
return data
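# Illustrative usage sketch (not part of the original source). The annotation
# path, image prefix and normalization values below are placeholders; any
# COCO-style dataset works the same way:
#
#     img_norm_cfg = dict(
#         mean=[123.675, 116.28, 103.53],
#         std=[58.395, 57.12, 57.375],
#         to_rgb=True)
#     dataset = CocoDataset(
#         ann_file='data/coco/annotations/instances_train2017.json',
#         img_prefix='data/coco/train2017/',
#         img_scale=(1333, 800),
#         img_norm_cfg=img_norm_cfg,
#         size_divisor=32,
#         flip_ratio=0.5,
#         with_mask=True,
#         with_crowd=True,
#         with_label=True)
#     data = dataset[0]  # dict of DataContainers: img, img_meta, gt_bboxes, ...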
from .build_loader import build_dataloader
from .sampler import GroupSampler, DistributedGroupSampler
__all__ = [
'GroupSampler', 'DistributedGroupSampler', 'build_dataloader'
]
from functools import partial
from mmcv.runner import get_dist_info
from mmcv.parallel import collate
from torch.utils.data import DataLoader
from .sampler import GroupSampler, DistributedGroupSampler
# https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
def build_dataloader(dataset,
imgs_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
**kwargs):
if dist:
rank, world_size = get_dist_info()
sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
rank)
batch_size = imgs_per_gpu
num_workers = workers_per_gpu
else:
sampler = GroupSampler(dataset, imgs_per_gpu)
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu
if not kwargs.get('shuffle', True):
sampler = None
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
pin_memory=False,
**kwargs)
return data_loader
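# Illustrative usage sketch (not part of the original source). `dataset` is
# assumed to be a CocoDataset (or any dataset exposing the `flag` attribute
# required by the group samplers):
#
#     data_loader = build_dataloader(
#         dataset,
#         imgs_per_gpu=2,
#         workers_per_gpu=2,
#         num_gpus=8,
#         dist=False)
#
# With dist=True the rank and world size are taken from get_dist_info() and a
# DistributedGroupSampler is used, so the batch size stays imgs_per_gpu per
# process.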
from __future__ import division
import math
import torch
import numpy as np
from torch.distributed import get_world_size, get_rank
from torch.utils.data.sampler import Sampler
class GroupSampler(Sampler):
def __init__(self, dataset, samples_per_gpu=1):
assert hasattr(dataset, 'flag')
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.flag = dataset.flag.astype(np.int64)
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, size in enumerate(self.group_sizes):
self.num_samples += int(np.ceil(
size / self.samples_per_gpu)) * self.samples_per_gpu
def __iter__(self):
indices = []
for i, size in enumerate(self.group_sizes):
if size == 0:
continue
indice = np.where(self.flag == i)[0]
assert len(indice) == size
np.random.shuffle(indice)
num_extra = int(np.ceil(size / self.samples_per_gpu)
) * self.samples_per_gpu - len(indice)
indice = np.concatenate([indice, indice[:num_extra]])
indices.append(indice)
indices = np.concatenate(indices)
indices = [
indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
for i in np.random.permutation(
range(len(indices) // self.samples_per_gpu))
]
indices = np.concatenate(indices)
indices = torch.from_numpy(indices).long()
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
class DistributedGroupSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
"""
def __init__(self,
dataset,
samples_per_gpu=1,
num_replicas=None,
rank=None):
if num_replicas is None:
num_replicas = get_world_size()
if rank is None:
rank = get_rank()
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
assert hasattr(self.dataset, 'flag')
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = []
for i, size in enumerate(self.group_sizes):
if size > 0:
indice = np.where(self.flag == i)[0]
assert len(indice) == size
indice = indice[list(torch.randperm(int(size),
generator=g))].tolist()
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas)
) * self.samples_per_gpu * self.num_replicas - len(indice)
indice += indice[:extra]
indices += indice
assert len(indices) == self.total_size
indices = [
indices[j] for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g))
for j in range(i * self.samples_per_gpu, (i + 1) *
self.samples_per_gpu)
]
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
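# Minimal usage sketch (not part of the original source), assuming the process
# group has already been initialised and `dataset` exposes a `flag` attribute:
#
#     sampler = DistributedGroupSampler(dataset, samples_per_gpu=2)
#     loader = DataLoader(dataset, batch_size=2, sampler=sampler, num_workers=2)
#     for epoch in range(total_epochs):
#         sampler.set_epoch(epoch)  # re-seed the deterministic shuffle
#         for batch in loader:
#             ...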
import mmcv
import numpy as np
import torch
__all__ = ['ImageTransform', 'BboxTransform', 'MaskTransform', 'Numpy2Tensor']
class ImageTransform(object):
"""Preprocess an image.
1. rescale the image to expected size
2. normalize the image
3. flip the image (if needed)
4. pad the image (if needed)
5. transpose to (c, h, w)
"""
def __init__(self,
mean=(0, 0, 0),
std=(1, 1, 1),
to_rgb=True,
size_divisor=None):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
self.size_divisor = size_divisor
def __call__(self, img, scale, flip=False):
img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
img_shape = img.shape
img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
if flip:
img = mmcv.imflip(img)
if self.size_divisor is not None:
img = mmcv.impad_to_multiple(img, self.size_divisor)
pad_shape = img.shape
else:
pad_shape = img_shape
img = img.transpose(2, 0, 1)
return img, img_shape, pad_shape, scale_factor
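# Illustrative usage sketch (not part of the original source); the mean/std
# values below are the usual ImageNet statistics and are only an assumption:
#
#     transform = ImageTransform(
#         mean=(123.675, 116.28, 103.53),
#         std=(58.395, 57.12, 57.375),
#         to_rgb=True,
#         size_divisor=32)
#     img, img_shape, pad_shape, scale_factor = transform(
#         mmcv.imread('demo.jpg'), scale=(1333, 800), flip=False)
#     # img is a normalized, padded (c, h, w) float32 array ready for to_tensor.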
def bbox_flip(bboxes, img_shape):
"""Flip bboxes horizontally.
Args:
bboxes(ndarray): shape (..., 4*k)
img_shape(tuple): (height, width)
"""
assert bboxes.shape[-1] % 4 == 0
w = img_shape[1]
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
return flipped
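# Quick sanity check of the flipping formula (illustrative, not part of the
# original source), assuming a 100-pixel-wide image:
#
#     bbox_flip(np.array([[10., 20., 30., 40.]]), img_shape=(50, 100))
#     # -> array([[69., 20., 89., 40.]])  x1' = 100 - 30 - 1, x2' = 100 - 10 - 1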
class BboxTransform(object):
"""Preprocess gt bboxes.
1. rescale bboxes according to image size
2. flip bboxes (if needed)
3. pad the first dimension to `max_num_gts`
"""
def __init__(self, max_num_gts=None):
self.max_num_gts = max_num_gts
def __call__(self, bboxes, img_shape, scale_factor, flip=False):
gt_bboxes = bboxes * scale_factor
if flip:
gt_bboxes = bbox_flip(gt_bboxes, img_shape)
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1])
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0])
if self.max_num_gts is None:
return gt_bboxes
else:
num_gts = gt_bboxes.shape[0]
padded_bboxes = np.zeros((self.max_num_gts, 4), dtype=np.float32)
padded_bboxes[:num_gts, :] = gt_bboxes
return padded_bboxes
class MaskTransform(object):
"""Preprocess masks.
1. resize masks to expected size and stack to a single array
2. flip the masks (if needed)
3. pad the masks (if needed)
"""
def __call__(self, masks, pad_shape, scale_factor, flip=False):
masks = [
mmcv.imrescale(mask, scale_factor, interpolation='nearest')
for mask in masks
]
if flip:
masks = [mask[:, ::-1] for mask in masks]
padded_masks = [
mmcv.impad(mask, pad_shape[:2], pad_val=0) for mask in masks
]
padded_masks = np.stack(padded_masks, axis=0)
return padded_masks
class Numpy2Tensor(object):
def __init__(self):
pass
def __call__(self, *args):
if len(args) == 1:
return torch.from_numpy(args[0])
else:
return tuple([torch.from_numpy(np.array(array)) for array in args])
from collections.abc import Sequence
import mmcv
import torch
import matplotlib.pyplot as plt
import numpy as np
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError('type {} cannot be converted to tensor.'.format(
type(data)))
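# Illustrative examples (not part of the original source):
#
#     to_tensor(np.zeros((2, 3), dtype=np.float32))  # (2, 3) float tensor
#     to_tensor([1, 2, 3])                           # tensor([1, 2, 3])
#     to_tensor(5)                                   # tensor([5]) (LongTensor)
#     to_tensor(0.5)                                 # tensor([0.5]) (FloatTensor)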
def random_scale(img_scales, mode='range'):
"""Randomly select a scale from a list of scales or scale ranges.
Args:
img_scales (list[tuple]): Image scale or scale range.
mode (str): "range" or "value".
Returns:
tuple: Sampled image scale.
"""
num_scales = len(img_scales)
if num_scales == 1: # fixed scale is specified
img_scale = img_scales[0]
elif num_scales == 2: # randomly sample a scale
if mode == 'range':
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long),
max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short),
max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
elif mode == 'value':
img_scale = img_scales[np.random.randint(num_scales)]
else:
if mode != 'value':
raise ValueError(
'Only "value" mode supports more than 2 image scales')
img_scale = img_scales[np.random.randint(num_scales)]
return img_scale
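# Illustrative behaviour (not part of the original source):
#
#     random_scale([(1333, 800)])               # -> (1333, 800), fixed scale
#     random_scale([(1333, 640), (1333, 800)])  # 'range' mode: long edge 1333,
#                                               #   short edge sampled in [640, 800]
#     random_scale([(1333, 480), (1333, 960)], mode='value')  # picks one tuple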
def show_ann(coco, img, ann_info):
plt.imshow(mmcv.bgr2rgb(img))
plt.axis('off')
coco.showAnns(ann_info)
plt.show()
from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
FasterRCNN, MaskRCNN)
from .builder import (build_backbone, build_neck, build_rpn_head,
                      build_roi_extractor, build_bbox_head, build_mask_head,
                      build_detector)
__all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
'MaskRCNN', 'build_backbone', 'build_neck', 'build_rpn_head',
'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
'build_detector'
]
from .resnet import ResNet
__all__ = ['ResNet']
import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"3x3 convolution with padding"
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
assert not with_cp
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
"""Bottleneck block.
        If style is "pytorch", the stride-two layer is the 3x3 conv layer;
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
if style == 'pytorch':
conv1_stride = 1
conv2_stride = stride
else:
conv1_stride = stride
conv2_stride = 1
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
def forward(self, x):
def _inner_forward(x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
style='pytorch',
with_cp=False):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(
inplanes,
planes,
stride,
dilation,
downsample,
style=style,
with_cp=with_cp))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
return nn.Sequential(*layers)
class ResNet(nn.Module):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices
self.style = style
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
stride = strides[i]
dilation = dilations[i]
planes = 64 * 2**i
res_layer = make_res_layer(
block,
self.inplanes,
planes,
num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for param in self.bn1.parameters():
param.requires_grad = False
self.bn1.eval()
self.bn1.weight.requires_grad = False
self.bn1.bias.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
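# The block below is an illustrative smoke test added for this write-up, not
# part of the original file: build a ResNet-50 backbone and run a dummy
# forward pass.
if __name__ == '__main__':
    import torch
    backbone = ResNet(depth=50, out_indices=(0, 1, 2, 3), frozen_stages=1)
    backbone.init_weights(pretrained=None)
    feats = backbone(torch.randn(1, 3, 224, 224))
    # Four feature maps with 256/512/1024/2048 channels at strides 4/8/16/32.
    print([f.shape for f in feats])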
from .bbox_head import BBoxHead
from .convfc_bbox_head import ConvFCRoIHead, SharedFCRoIHead
__all__ = ['BBoxHead', 'ConvFCRoIHead', 'SharedFCRoIHead']
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
class BBoxHead(nn.Module):
"""Simplest RoI head, with only two fc layers for classification and
regression respectively"""
def __init__(self,
with_avg_pool=False,
with_cls=True,
with_reg=True,
roi_feat_size=7,
in_channels=256,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False):
super(BBoxHead, self).__init__()
assert with_cls or with_reg
self.with_avg_pool = with_avg_pool
self.with_cls = with_cls
self.with_reg = with_reg
self.roi_feat_size = roi_feat_size
self.in_channels = in_channels
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.reg_class_agnostic = reg_class_agnostic
in_channels = self.in_channels
if self.with_avg_pool:
self.avg_pool = nn.AvgPool2d(roi_feat_size)
else:
in_channels *= (self.roi_feat_size * self.roi_feat_size)
if self.with_cls:
self.fc_cls = nn.Linear(in_channels, num_classes)
if self.with_reg:
out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
self.fc_reg = nn.Linear(in_channels, out_dim_reg)
self.debug_imgs = None
def init_weights(self):
if self.with_cls:
nn.init.normal_(self.fc_cls.weight, 0, 0.01)
nn.init.constant_(self.fc_cls.bias, 0)
if self.with_reg:
nn.init.normal_(self.fc_reg.weight, 0, 0.001)
nn.init.constant_(self.fc_reg.bias, 0)
def forward(self, x):
if self.with_avg_pool:
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
cls_score = self.fc_cls(x) if self.with_cls else None
bbox_pred = self.fc_reg(x) if self.with_reg else None
return cls_score, bbox_pred
def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes,
pos_gt_labels, rcnn_train_cfg):
reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes
cls_reg_targets = bbox_target(
pos_proposals,
neg_proposals,
pos_gt_bboxes,
pos_gt_labels,
rcnn_train_cfg,
reg_num_classes,
target_means=self.target_means,
target_stds=self.target_stds)
return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
cls_score, labels, label_weights)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
losses['loss_reg'] = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=bbox_targets.size(0))
return losses
def get_det_bboxes(self,
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
nms_cfg=None):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
if bbox_pred is not None:
bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
self.target_stds, img_shape)
else:
bboxes = rois[:, 1:]
# TODO: add clip here
if rescale:
bboxes /= scale_factor
if nms_cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(
bboxes, scores, nms_cfg.score_thr, nms_cfg.nms_thr,
nms_cfg.max_per_img)
return det_bboxes, det_labels
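# Illustrative shape check added for this write-up (not part of the original
# file): the default head consumes 256x7x7 RoI features and predicts 81
# classes (80 things + background).
if __name__ == '__main__':
    import torch
    head = BBoxHead(in_channels=256, roi_feat_size=7, num_classes=81)
    head.init_weights()
    cls_score, bbox_pred = head(torch.randn(4, 256, 7, 7))
    # cls_score: (4, 81), bbox_pred: (4, 324), i.e. 4 deltas per class.
    print(cls_score.shape, bbox_pred.shape)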
import torch.nn as nn
from .bbox_head import BBoxHead
from ..utils import ConvModule
class ConvFCRoIHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
/-> cls convs -> cls fcs -> cls
shared convs -> shared fcs
\-> reg convs -> reg fcs -> reg
"""
def __init__(self,
num_shared_convs=0,
num_shared_fcs=0,
num_cls_convs=0,
num_cls_fcs=0,
num_reg_convs=0,
num_reg_fcs=0,
conv_out_channels=256,
fc_out_channels=1024,
*args,
**kwargs):
super(ConvFCRoIHead, self).__init__(*args, **kwargs)
assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs
+ num_reg_convs + num_reg_fcs > 0)
if num_cls_convs > 0 or num_reg_convs > 0:
assert num_shared_fcs == 0
if not self.with_cls:
assert num_cls_convs == 0 and num_cls_fcs == 0
if not self.with_reg:
assert num_reg_convs == 0 and num_reg_fcs == 0
self.num_shared_convs = num_shared_convs
self.num_shared_fcs = num_shared_fcs
self.num_cls_convs = num_cls_convs
self.num_cls_fcs = num_cls_fcs
self.num_reg_convs = num_reg_convs
self.num_reg_fcs = num_reg_fcs
self.conv_out_channels = conv_out_channels
self.fc_out_channels = fc_out_channels
# add shared convs and fcs
self.shared_convs, self.shared_fcs, last_layer_dim = \
self._add_conv_fc_branch(
self.num_shared_convs, self.num_shared_fcs, self.in_channels,
True)
self.shared_out_channels = last_layer_dim
# add cls specific branch
self.cls_convs, self.cls_fcs, self.cls_last_dim = \
self._add_conv_fc_branch(
self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
# add reg specific branch
self.reg_convs, self.reg_fcs, self.reg_last_dim = \
self._add_conv_fc_branch(
self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
if self.num_shared_fcs == 0 and not self.with_avg_pool:
if self.num_cls_fcs == 0:
self.cls_last_dim *= (self.roi_feat_size * self.roi_feat_size)
if self.num_reg_fcs == 0:
self.reg_last_dim *= (self.roi_feat_size * self.roi_feat_size)
self.relu = nn.ReLU(inplace=True)
# reconstruct fc_cls and fc_reg since input channels are changed
if self.with_cls:
self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes)
if self.with_reg:
out_dim_reg = (4 if self.reg_class_agnostic else
4 * self.num_classes)
self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)
def _add_conv_fc_branch(self,
num_branch_convs,
num_branch_fcs,
in_channels,
is_shared=False):
"""Add shared or separable branch
convs -> avg pool (optional) -> fcs
"""
last_layer_dim = in_channels
# add branch specific conv layers
branch_convs = nn.ModuleList()
if num_branch_convs > 0:
for i in range(num_branch_convs):
conv_in_channels = (last_layer_dim
if i == 0 else self.conv_out_channels)
branch_convs.append(
ConvModule(
conv_in_channels,
self.conv_out_channels,
3,
padding=1,
normalize=self.normalize,
bias=self.with_bias))
last_layer_dim = self.conv_out_channels
# add branch specific fc layers
branch_fcs = nn.ModuleList()
if num_branch_fcs > 0:
# for shared branch, only consider self.with_avg_pool
# for separated branches, also consider self.num_shared_fcs
if (is_shared
or self.num_shared_fcs == 0) and not self.with_avg_pool:
last_layer_dim *= (self.roi_feat_size * self.roi_feat_size)
for i in range(num_branch_fcs):
fc_in_channels = (last_layer_dim
if i == 0 else self.fc_out_channels)
branch_fcs.append(
nn.Linear(fc_in_channels, self.fc_out_channels))
last_layer_dim = self.fc_out_channels
return branch_convs, branch_fcs, last_layer_dim
def init_weights(self):
super(ConvFCRoIHead, self).init_weights()
for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
for m in module_list.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x):
# shared part
if self.num_shared_convs > 0:
for conv in self.shared_convs:
x = conv(x)
if self.num_shared_fcs > 0:
if self.with_avg_pool:
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
for fc in self.shared_fcs:
x = self.relu(fc(x))
# separate branches
x_cls = x
x_reg = x
for conv in self.cls_convs:
x_cls = conv(x_cls)
if x_cls.dim() > 2:
if self.with_avg_pool:
x_cls = self.avg_pool(x_cls)
x_cls = x_cls.view(x_cls.size(0), -1)
for fc in self.cls_fcs:
x_cls = self.relu(fc(x_cls))
for conv in self.reg_convs:
x_reg = conv(x_reg)
if x_reg.dim() > 2:
if self.with_avg_pool:
x_reg = self.avg_pool(x_reg)
x_reg = x_reg.view(x_reg.size(0), -1)
for fc in self.reg_fcs:
x_reg = self.relu(fc(x_reg))
cls_score = self.fc_cls(x_cls) if self.with_cls else None
bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
return cls_score, bbox_pred
class SharedFCRoIHead(ConvFCRoIHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
assert num_fcs >= 1
super(SharedFCRoIHead, self).__init__(
num_shared_convs=0,
num_shared_fcs=num_fcs,
num_cls_convs=0,
num_cls_fcs=0,
num_reg_convs=0,
num_reg_fcs=0,
fc_out_channels=fc_out_channels,
*args,
**kwargs)
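# Illustrative usage sketch (not part of the original source): the standard
# two-fc R-CNN head, matching the diagram in ConvFCRoIHead's docstring.
#
#     head = SharedFCRoIHead(
#         num_fcs=2,
#         in_channels=256,
#         fc_out_channels=1024,
#         roi_feat_size=7,
#         num_classes=81)
#     cls_score, bbox_pred = head(roi_feats)  # roi_feats: (num_rois, 256, 7, 7)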
from mmcv.runner import obj_from_dict
from torch import nn
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
mask_heads)
__all__ = [
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
'build_bbox_head', 'build_mask_head', 'build_detector'
]
def _build_module(cfg, parent=None, default_args=None):
    return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
        cfg, parent, default_args)
def build(cfg, parent=None, default_args=None):
    if isinstance(cfg, list):
        modules = [_build_module(cfg_, parent, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return _build_module(cfg, parent, default_args)
def build_backbone(cfg):
return build(cfg, backbones)
def build_neck(cfg):
return build(cfg, necks)
def build_rpn_head(cfg):
return build(cfg, rpn_heads)
def build_roi_extractor(cfg):
return build(cfg, roi_extractors)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
def build_mask_head(cfg):
return build(cfg, mask_heads)
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
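# Illustrative usage sketch (not part of the original source); the config
# fragments below are hypothetical but follow the dict-based style these
# builders expect (cfg here would come from an mmcv config file):
#
#     backbone = build_backbone(
#         dict(type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3),
#              frozen_stages=1, style='pytorch'))
#     detector = build_detector(cfg.model, train_cfg=cfg.train_cfg,
#                               test_cfg=cfg.test_cfg)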
from .base import BaseDetector
from .two_stage import TwoStageDetector
from .rpn import RPN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
__all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
'MaskRCNN'
]
import logging
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch.nn as nn
from mmdet.core import tensor2imgs, get_classes
class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors"""
def __init__(self):
super(BaseDetector, self).__init__()
@property
def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None
@property
def with_bbox(self):
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_mask(self):
return hasattr(self, 'mask_head') and self.mask_head is not None
@abstractmethod
def extract_feat(self, imgs):
pass
def extract_feats(self, imgs):
assert isinstance(imgs, list)
for img in imgs:
yield self.extract_feat(img)
@abstractmethod
def forward_train(self, imgs, img_metas, **kwargs):
pass
@abstractmethod
def simple_test(self, img, img_meta, **kwargs):
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
pass
def init_weights(self, pretrained=None):
if pretrained is not None:
logger = logging.getLogger()
logger.info('load model from: {}'.format(pretrained))
def forward_test(self, imgs, img_metas, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
def forward(self, img, img_meta, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(img, img_meta, **kwargs)
else:
return self.forward_test(img, img_meta, **kwargs)
def show_result(self,
data,
result,
img_norm_cfg,
dataset='coco',
score_thr=0.3):
img_tensor = data['img'][0]
img_metas = data['img_meta'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
if isinstance(dataset, str):
class_names = get_classes(dataset)
elif isinstance(dataset, list):
class_names = dataset
else:
raise TypeError('dataset must be a valid dataset name or a list'
' of class names, not {}'.format(type(dataset)))
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
mmcv.imshow_det_bboxes(
img_show,
bboxes,
labels,
class_names=class_names,
score_thr=score_thr)
from .two_stage import TwoStageDetector
class FastRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
bbox_roi_extractor,
bbox_head,
train_cfg,
test_cfg,
mask_roi_extractor=None,
mask_head=None,
pretrained=None):
super(FastRCNN, self).__init__(
backbone=backbone,
neck=neck,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
pretrained=pretrained)
def forward_test(self, imgs, img_metas, proposals, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], proposals[0],
**kwargs)
else:
return self.aug_test(imgs, img_metas, proposals, **kwargs)
from .two_stage import TwoStageDetector
class FasterRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
rpn_head,
bbox_roi_extractor,
bbox_head,
train_cfg,
test_cfg,
pretrained=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from .two_stage import TwoStageDetector
class MaskRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
rpn_head,
bbox_roi_extractor,
bbox_head,
mask_roi_extractor,
mask_head,
train_cfg,
test_cfg,
pretrained=None):
super(MaskRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
def show_result(self, data, result, img_norm_cfg, **kwargs):
# TODO: show segmentation masks
assert isinstance(result, tuple)
assert len(result) == 2 # (bbox_results, segm_results)
super(MaskRCNN, self).show_result(data, result[0], img_norm_cfg,
**kwargs)