Commit d88ad8fc authored by tyomj's avatar tyomj Committed by Kai Chen
Browse files

Albumentations augs wrapper (#1354)

* Albumentations wrapper

* 2 single quote format

* version >= 0.3.2
parent 6ee5e4d6
# model settings
model = dict(
type='MaskRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='FCNMaskHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=81,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=28,
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
albu_train_transforms = [
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(
type='OneOf',
transforms=[
dict(
type='RGBShift',
r_shift_limit=10,
g_shift_limit=10,
b_shift_limit=10,
p=1.0),
dict(
type='HueSaturationValue',
hue_shift_limit=20,
sat_shift_limit=30,
val_shift_limit=20,
p=1.0)
],
p=0.1),
dict(type='JpegCompression', quality_lower=85, quality_upper=95, p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='Pad', size_divisor=32),
dict(
type='Albu',
transforms=albu_train_transforms,
bbox_params=dict(
type='BboxParams',
format='pascal_voc',
label_fields=['gt_labels'],
min_visibility=0.0,
filter_lost_elements=True),
keymap={
'img': 'image',
'gt_masks': 'masks',
'gt_bboxes': 'bboxes'
},
update_pad_shape=False,
skip_img_without_anno=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
meta_keys=('filename', 'ori_shape', 'img_shape', 'img_norm_cfg',
'pad_shape', 'scale_factor'))
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
evaluation = dict(interval=1)
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mask_rcnn_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
...@@ -3,7 +3,7 @@ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, ...@@ -3,7 +3,7 @@ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
Transpose, to_tensor) Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
from .test_aug import MultiScaleFlipAug from .test_aug import MultiScaleFlipAug
from .transforms import (Expand, MinIoURandomCrop, Normalize, Pad, from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip, Resize, PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
SegResizeFlipPadRescale) SegResizeFlipPadRescale)
...@@ -12,5 +12,5 @@ __all__ = [ ...@@ -12,5 +12,5 @@ __all__ = [
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop', 'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop',
'Expand', 'PhotoMetricDistortion' 'Expand', 'PhotoMetricDistortion', 'Albu'
] ]
import inspect
import albumentations
import mmcv import mmcv
import numpy as np import numpy as np
from albumentations import Compose
from imagecorruptions import corrupt from imagecorruptions import corrupt
from numpy import random from numpy import random
...@@ -596,9 +600,8 @@ class MinIoURandomCrop(object): ...@@ -596,9 +600,8 @@ class MinIoURandomCrop(object):
# center of boxes should inside the crop img # center of boxes should inside the crop img
center = (boxes[:, :2] + boxes[:, 2:]) / 2 center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = (center[:, 0] > patch[0]) * ( mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) *
center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * ( (center[:, 0] < patch[2]) * (center[:, 1] < patch[3]))
center[:, 1] < patch[3])
if not mask.any(): if not mask.any():
continue continue
boxes = boxes[mask] boxes = boxes[mask]
...@@ -651,3 +654,155 @@ class Corrupt(object): ...@@ -651,3 +654,155 @@ class Corrupt(object):
repr_str += '(corruption={}, severity={})'.format( repr_str += '(corruption={}, severity={})'.format(
self.corruption, self.severity) self.corruption, self.severity)
return repr_str return repr_str
@PIPELINES.register_module
class Albu(object):
def __init__(self,
transforms,
bbox_params=None,
keymap=None,
update_pad_shape=False,
skip_img_without_anno=False):
"""
Adds custom transformations from Albumentations lib.
Please, visit `https://albumentations.readthedocs.io`
to get more information.
transforms (list): list of albu transformations
bbox_params (dict): bbox_params for albumentation `Compose`
keymap (dict): contains {'input key':'albumentation-style key'}
skip_img_without_anno (bool): whether to skip the image
if no ann left after aug
"""
self.transforms = transforms
self.filter_lost_elements = False
self.update_pad_shape = update_pad_shape
self.skip_img_without_anno = skip_img_without_anno
# A simple workaround to remove masks without boxes
if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
and 'filter_lost_elements' in bbox_params):
self.filter_lost_elements = True
self.origin_label_fields = bbox_params['label_fields']
bbox_params['label_fields'] = ['idx_mapper']
del bbox_params['filter_lost_elements']
self.bbox_params = (
self.albu_builder(bbox_params) if bbox_params else None)
self.aug = Compose([self.albu_builder(t) for t in self.transforms],
bbox_params=self.bbox_params)
if not keymap:
self.keymap_to_albu = {
'img': 'image',
'gt_masks': 'masks',
'gt_bboxes': 'bboxes'
}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
def albu_builder(self, cfg):
"""Import a module from albumentations.
Inherits some of `build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and "type" in cfg
args = cfg.copy()
obj_type = args.pop("type")
if mmcv.is_str(obj_type):
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
'type must be a str or valid type, but got {}'.format(
type(obj_type)))
if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]
return obj_cls(**args)
@staticmethod
def mapper(d, keymap):
"""
Dictionary mapper.
Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""
updated_dict = {}
for k, v in zip(d.keys(), d.values()):
new_k = keymap.get(k, k)
updated_dict[new_k] = d[k]
return updated_dict
def __call__(self, results):
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
if 'bboxes' in results:
# to list of boxes
if isinstance(results['bboxes'], np.ndarray):
results['bboxes'] = [x for x in results['bboxes']]
# add pseudo-field for filtration
if self.filter_lost_elements:
results['idx_mapper'] = np.arange(len(results['bboxes']))
results = self.aug(**results)
if 'bboxes' in results:
if isinstance(results['bboxes'], list):
results['bboxes'] = np.array(
results['bboxes'], dtype=np.float32)
# filter label_fields
if self.filter_lost_elements:
results['idx_mapper'] = np.arange(len(results['bboxes']))
for label in self.origin_label_fields:
results[label] = np.array(
[results[label][i] for i in results['idx_mapper']])
if 'masks' in results:
results['masks'] = [
results['masks'][i] for i in results['idx_mapper']
]
if (not len(results['idx_mapper'])
and self.skip_img_without_anno):
return None
if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
# back to the original format
results = self.mapper(results, self.keymap_back)
# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(transformations={})'.format(self.transformations)
return repr_str
...@@ -7,3 +7,4 @@ pycocotools ...@@ -7,3 +7,4 @@ pycocotools
torch>=1.1 torch>=1.1
torchvision torchvision
imagecorruptions imagecorruptions
albumentations>=0.3.2
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment