Commit 76168f9c authored by ThangVu

Resolve conflicts between GN-dev and master

parents 8a086f02 c5d8f002
mmdet/datasets/__init__.py

 from .custom import CustomDataset
+from .xml_style import XMLDataset
 from .coco import CocoDataset
+from .voc import VOCDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
 from .utils import to_tensor, random_scale, show_ann, get_dataset
 from .concat_dataset import ConcatDataset
 from .repeat_dataset import RepeatDataset
+from .extra_aug import ExtraAugmentation

 __all__ = [
-    'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
-    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
-    'get_dataset', 'ConcatDataset', 'RepeatDataset',
+    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
+    'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
+    'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset',
+    'ExtraAugmentation'
 ]
mmdet/datasets/coco.py

@@ -6,6 +6,21 @@ from .custom import CustomDataset
 class CocoDataset(CustomDataset):

+    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+               'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
+               'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
+               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+               'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
+               'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
+               'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+               'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+               'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
+               'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
+               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+               'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
+
     def load_annotations(self, ann_file):
         self.coco = COCO(ann_file)
         self.cat_ids = self.coco.getCatIds()
mmdet/datasets/concat_dataset.py

@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 class ConcatDataset(_ConcatDataset):
-    """
-    Same as torch.utils.data.dataset.ConcatDataset, but
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
     concat the group flag for image aspect ratio.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
     """

     def __init__(self, datasets):
-        """
-        flag: Images with aspect ratio greater than 1 will be set as group 1,
-        otherwise group 0.
-        """
         super(ConcatDataset, self).__init__(datasets)
+        self.CLASSES = datasets[0].CLASSES
         if hasattr(datasets[0], 'flag'):
            flags = []
            for i in range(0, len(datasets)):
...
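As a quick illustration of the flag bookkeeping above — assuming the loop truncated in the hunk appends each child dataset's flag array and concatenates them, as the surrounding code suggests — here is a minimal sketch with hypothetical stub datasets:

import numpy as np
from torch.utils.data import Dataset

class _Stub(Dataset):
    CLASSES = ('cat', 'dog')
    def __init__(self, flags):
        # per-image aspect-ratio group flags, as produced by CustomDataset
        self.flag = np.asarray(flags, dtype=np.uint8)
    def __len__(self):
        return len(self.flag)
    def __getitem__(self, idx):
        return idx

concat = ConcatDataset([_Stub([0, 1]), _Stub([1, 1, 0])])
assert concat.CLASSES == ('cat', 'dog')         # taken from the first dataset
assert concat.flag.tolist() == [0, 1, 1, 1, 0]  # flags stitched end to end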
mmdet/datasets/custom.py

@@ -8,6 +8,7 @@ from torch.utils.data import Dataset
 from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          Numpy2Tensor)
 from .utils import to_tensor, random_scale
+from .extra_aug import ExtraAugmentation


 class CustomDataset(Dataset):
@@ -32,6 +33,8 @@ class CustomDataset(Dataset):
     The `ann` field is optional for testing.
     """

+    CLASSES = None
+
     def __init__(self,
                  ann_file,
                  img_prefix,
@@ -44,7 +47,12 @@ class CustomDataset(Dataset):
                  with_mask=True,
                  with_crowd=True,
                  with_label=True,
+                 extra_aug=None,
+                 resize_keep_ratio=True,
                  test_mode=False):
+        # prefix of the image paths
+        self.img_prefix = img_prefix
+
         # load annotations (and proposals)
         self.img_infos = self.load_annotations(ann_file)
         if proposal_file is not None:
@@ -58,8 +66,6 @@ class CustomDataset(Dataset):
         if self.proposals is not None:
             self.proposals = [self.proposals[i] for i in valid_inds]
-        # prefix of images path
-        self.img_prefix = img_prefix
         # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
         self.img_scales = img_scale if isinstance(img_scale,
                                                   list) else [img_scale]
@@ -96,6 +102,15 @@ class CustomDataset(Dataset):
         self.mask_transform = MaskTransform()
         self.numpy2tensor = Numpy2Tensor()

+        # build the extra augmentation pipeline if configured
+        if extra_aug is not None:
+            self.extra_aug = ExtraAugmentation(**extra_aug)
+        else:
+            self.extra_aug = None
+
+        # whether to keep the aspect ratio when resizing images
+        self.resize_keep_ratio = resize_keep_ratio
+
     def __len__(self):
         return len(self.img_infos)
@@ -174,11 +189,17 @@ class CustomDataset(Dataset):
         if len(gt_bboxes) == 0:
             return None

+        # extra augmentation
+        if self.extra_aug is not None:
+            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
+                                                       gt_labels)
+
         # apply transforms
         flip = True if np.random.rand() < self.flip_ratio else False
         img_scale = random_scale(self.img_scales)  # sample a scale
         img, img_shape, pad_shape, scale_factor = self.img_transform(
-            img, img_scale, flip)
+            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
+        img = img.copy()
         if self.proposals is not None:
             proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                             flip)
@@ -230,7 +251,7 @@ class CustomDataset(Dataset):
         def prepare_single(img, scale, flip, proposal=None):
             _img, img_shape, pad_shape, scale_factor = self.img_transform(
-                img, scale, flip)
+                img, scale, flip, keep_ratio=self.resize_keep_ratio)
             _img = to_tensor(_img)
             _img_meta = dict(
                 ori_shape=(img_info['height'], img_info['width'], 3),
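For reference, the new extra_aug argument is unpacked straight into ExtraAugmentation (defined in extra_aug.py below), so a dataset config can enable the SSD-style transforms declaratively. A minimal sketch of such a config fragment follows; the file paths, scale, and mean values here are illustrative and not taken from this commit:

train_dataset = dict(
    type='CocoDataset',
    ann_file='annotations/instances_train2017.json',  # illustrative path
    img_prefix='train2017/',                          # illustrative path
    img_scale=(300, 300),
    flip_ratio=0.5,
    extra_aug=dict(
        photo_metric_distortion=dict(
            brightness_delta=32,
            contrast_range=(0.5, 1.5),
            saturation_range=(0.5, 1.5),
            hue_delta=18),
        expand=dict(mean=(123.675, 116.28, 103.53), to_rgb=True,
                    ratio_range=(1, 4)),
        random_crop=dict(min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
                         min_crop_size=0.3)),
    resize_keep_ratio=False)

A real config would also carry the other keys CustomDataset expects (img_norm_cfg, size_divisor, and so on); only the fields touched by this commit are shown.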
mmdet/datasets/extra_aug.py (new file)

import mmcv
import numpy as np
from numpy import random

from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps


class PhotoMetricDistortion(object):

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, img, boxes, labels):
        # expects a float image (ExtraAugmentation casts to float32 first)
        # random brightness
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue, wrapped back into the [0, 360] range
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels


class Expand(object):

    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        if to_rgb:
            # the canvas is filled in BGR order, so reverse an RGB mean
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        if random.randint(2):
            return img, boxes, labels

        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        # paste the image at a random offset on an enlarged, mean-filled canvas
        expand_img = np.full((int(h * ratio), int(w * ratio), c),
                             self.mean).astype(img.dtype)
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img
        img = expand_img
        # shift boxes by the paste offset
        boxes += np.tile((left, top), 2)
        return img, boxes, labels


class RandomCrop(object):

    def __init__(self,
                 min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
                 min_crop_size=0.3):
        # mode 1 means: keep the original image
        self.sample_mode = (1, *min_ious, 0)
        self.min_crop_size = min_crop_size

    def __call__(self, img, boxes, labels):
        h, w, c = img.shape
        while True:
            mode = random.choice(self.sample_mode)
            if mode == 1:
                return img, boxes, labels

            min_iou = mode
            for i in range(50):
                new_w = random.uniform(self.min_crop_size * w, w)
                new_h = random.uniform(self.min_crop_size * h, h)

                # h / w in [0.5, 2]
                if new_h / new_w < 0.5 or new_h / new_w > 2:
                    continue

                left = random.uniform(0, w - new_w)
                top = random.uniform(0, h - new_h)

                patch = np.array((int(left), int(top), int(left + new_w),
                                  int(top + new_h)))
                overlaps = bbox_overlaps(
                    patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
                if overlaps.min() < min_iou:
                    continue

                # the center of each box must lie inside the cropped patch
                center = (boxes[:, :2] + boxes[:, 2:]) / 2
                mask = (center[:, 0] > patch[0]) * (
                    center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (
                        center[:, 1] < patch[3])
                if not mask.any():
                    continue
                boxes = boxes[mask]
                labels = labels[mask]

                # crop the image and clip/shift the surviving boxes
                img = img[patch[1]:patch[3], patch[0]:patch[2]]
                boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                boxes -= np.tile(patch[:2], 2)

                return img, boxes, labels


class ExtraAugmentation(object):

    def __init__(self,
                 photo_metric_distortion=None,
                 expand=None,
                 random_crop=None):
        self.transforms = []
        if photo_metric_distortion is not None:
            self.transforms.append(
                PhotoMetricDistortion(**photo_metric_distortion))
        if expand is not None:
            self.transforms.append(Expand(**expand))
        if random_crop is not None:
            self.transforms.append(RandomCrop(**random_crop))

    def __call__(self, img, boxes, labels):
        img = img.astype(np.float32)
        for transform in self.transforms:
            img, boxes, labels = transform(img, boxes, labels)
        return img, boxes, labels
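A minimal end-to-end sketch of the chained transforms on a synthetic BGR image with one box (all values illustrative):

import numpy as np

aug = ExtraAugmentation(
    photo_metric_distortion=dict(brightness_delta=32),
    expand=dict(mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)),
    random_crop=dict(min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3))

img = np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8)  # BGR image
boxes = np.array([[50., 50., 200., 200.]], dtype=np.float32)    # x1, y1, x2, y2
labels = np.array([1], dtype=np.int64)

img, boxes, labels = aug(img, boxes, labels)
# img is now float32; boxes have been shifted and clipped to track the
# expand/crop geometry, and labels keeps only the surviving boxes.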
mmdet/datasets/repeat_dataset.py

@@ -6,12 +6,14 @@ class RepeatDataset(object):
     def __init__(self, dataset, times):
         self.dataset = dataset
         self.times = times
+        self.CLASSES = dataset.CLASSES
         if hasattr(self.dataset, 'flag'):
             self.flag = np.tile(self.dataset.flag, times)
-        self._original_length = len(self.dataset)
+
+        self._ori_len = len(self.dataset)

     def __getitem__(self, idx):
-        return self.dataset[idx % self._original_length]
+        return self.dataset[idx % self._ori_len]

     def __len__(self):
-        return self.times * self._original_length
+        return self.times * self._ori_len
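A quick sanity check of the wrap-around indexing, using a hypothetical two-item stub dataset:

import numpy as np

class _Tiny(object):
    CLASSES = ('cat',)
    flag = np.array([0, 1], dtype=np.uint8)
    def __len__(self):
        return 2
    def __getitem__(self, idx):
        return idx

r = RepeatDataset(_Tiny(), times=3)
assert len(r) == 6       # 3 passes over 2 items
assert r[5] == 1         # 5 % 2 == 1, indices wrap around
assert len(r.flag) == 6  # group flags are tiled alongside the data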
mmdet/datasets/transforms.py

@@ -25,8 +25,14 @@ class ImageTransform(object):
         self.to_rgb = to_rgb
         self.size_divisor = size_divisor

-    def __call__(self, img, scale, flip=False):
-        img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
+    def __call__(self, img, scale, flip=False, keep_ratio=True):
+        if keep_ratio:
+            img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
+        else:
+            img, w_scale, h_scale = mmcv.imresize(
+                img, scale, return_scale=True)
+            scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+                                    dtype=np.float32)
         img_shape = img.shape
         img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
         if flip:
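In isolation, the two code paths behave as follows (standalone mmcv calls; the (1333, 800) target scale is just an example):

import mmcv
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)

# keep_ratio=True: fit the image inside the target scale without distortion;
# a single scalar scale factor comes back.
rescaled, factor = mmcv.imrescale(img, (1333, 800), return_scale=True)

# keep_ratio=False: stretch to the exact target size; width and height get
# independent factors, replicated to length 4 so the same array can scale
# (x1, y1, x2, y2) boxes directly.
resized, w_scale, h_scale = mmcv.imresize(img, (1333, 800), return_scale=True)
box_scale = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)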
mmdet/datasets/voc.py (new file)

from .xml_style import XMLDataset


class VOCDataset(XMLDataset):

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        super(VOCDataset, self).__init__(**kwargs)
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')
[one collapsed file diff not shown]
mmdet/models/backbones/__init__.py

 from .resnet import ResNet
+from .resnext import ResNeXt
+from .ssd_vgg import SSDVGG

-__all__ = ['ResNet']
+__all__ = ['ResNet', 'ResNeXt', 'SSDVGG']
[seven collapsed file diffs not shown]
mmdet/models/anchor_heads/__init__.py

 from .retina_head import RetinaHead
+from .ssd_head import SSDHead

-__all__ = ['RetinaHead']
+__all__ = ['RetinaHead', 'SSDHead']
[two collapsed file diffs not shown]