Commit 9118a94a authored by Cao Yuhang, committed by Kai Chen
Browse files

Split SegResizeFlipPadRescale into different existing transforms (#1852)

* Split seg trans

* Modify cfg

* fix typo
parent b6966700
...@@ -221,7 +221,7 @@ train_pipeline = [ ...@@ -221,7 +221,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -217,7 +217,7 @@ train_pipeline = [ ...@@ -217,7 +217,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -205,7 +205,7 @@ train_pipeline = [ ...@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -205,7 +205,7 @@ train_pipeline = [ ...@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -205,7 +205,7 @@ train_pipeline = [ ...@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -207,7 +207,7 @@ train_pipeline = [ ...@@ -207,7 +207,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -207,7 +207,7 @@ train_pipeline = [ ...@@ -207,7 +207,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5), dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8), dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'), dict(type='DefaultFormatBundle'),
dict( dict(
type='Collect', type='Collect',
......
...@@ -98,6 +98,7 @@ class CustomDataset(Dataset): ...@@ -98,6 +98,7 @@ class CustomDataset(Dataset):
results['proposal_file'] = self.proposal_file results['proposal_file'] = self.proposal_file
results['bbox_fields'] = [] results['bbox_fields'] = []
results['mask_fields'] = [] results['mask_fields'] = []
results['seg_fields'] = []
def _filter_imgs(self, min_size=32): def _filter_imgs(self, min_size=32):
"""Filter images too small.""" """Filter images too small."""
......
...@@ -5,12 +5,12 @@ from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals ...@@ -5,12 +5,12 @@ from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
from .test_aug import MultiScaleFlipAug from .test_aug import MultiScaleFlipAug
from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad, from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip, Resize, PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
SegResizeFlipPadRescale) SegRescale)
__all__ = [ __all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop', 'RandomCrop', 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'Expand', 'PhotoMetricDistortion', 'Albu' 'PhotoMetricDistortion', 'Albu'
] ]
...@@ -91,6 +91,7 @@ class LoadAnnotations(object): ...@@ -91,6 +91,7 @@ class LoadAnnotations(object):
results['gt_semantic_seg'] = mmcv.imread( results['gt_semantic_seg'] = mmcv.imread(
osp.join(results['seg_prefix'], results['ann_info']['seg_map']), osp.join(results['seg_prefix'], results['ann_info']['seg_map']),
flag='unchanged').squeeze() flag='unchanged').squeeze()
results['seg_fields'].append('gt_semantic_seg')
return results return results
def __call__(self, results): def __call__(self, results):
......
...@@ -149,12 +149,23 @@ class Resize(object): ...@@ -149,12 +149,23 @@ class Resize(object):
] ]
results[key] = masks results[key] = masks
def _resize_seg(self, results):
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key], results['scale'], interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results[key], results['scale'], interpolation='nearest')
results['gt_semantic_seg'] = gt_seg
def __call__(self, results): def __call__(self, results):
if 'scale' not in results: if 'scale' not in results:
self._random_scale(results) self._random_scale(results)
self._resize_img(results) self._resize_img(results)
self._resize_bboxes(results) self._resize_bboxes(results)
self._resize_masks(results) self._resize_masks(results)
self._resize_seg(results)
return results return results
def __repr__(self): def __repr__(self):
...@@ -229,6 +240,11 @@ class RandomFlip(object): ...@@ -229,6 +240,11 @@ class RandomFlip(object):
mmcv.imflip(mask, direction=results['flip_direction']) mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key] for mask in results[key]
] ]
# flip segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imflip(
results[key], direction=results['flip_direction'])
return results return results
def __repr__(self): def __repr__(self):
...@@ -280,9 +296,14 @@ class Pad(object): ...@@ -280,9 +296,14 @@ class Pad(object):
else: else:
results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8) results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)
def _pad_seg(self, results):
for key in results.get('seg_fields', []):
results[key] = mmcv.impad(results[key], results['pad_shape'][:2])
def __call__(self, results): def __call__(self, results):
self._pad_img(results) self._pad_img(results)
self._pad_masks(results) self._pad_masks(results)
self._pad_seg(results)
return results return results
def __repr__(self): def __repr__(self):
...@@ -386,15 +407,8 @@ class RandomCrop(object): ...@@ -386,15 +407,8 @@ class RandomCrop(object):
@PIPELINES.register_module @PIPELINES.register_module
class SegResizeFlipPadRescale(object): class SegRescale(object):
"""A sequential transforms to semantic segmentation maps. """Rescale semantic segmentation maps.
The same pipeline as input images is applied to the semantic segmentation
map, and finally rescale it by some scale factor. The transforms include:
1. resize
2. flip
3. pad
4. rescale (so that the final size can be different from the image size)
Args: Args:
scale_factor (float): The scale factor of the final output. scale_factor (float): The scale factor of the final output.
...@@ -404,24 +418,10 @@ class SegResizeFlipPadRescale(object): ...@@ -404,24 +418,10 @@ class SegResizeFlipPadRescale(object):
self.scale_factor = scale_factor self.scale_factor = scale_factor
def __call__(self, results): def __call__(self, results):
if results['keep_ratio']: for key in results.get('seg_fields', []):
gt_seg = mmcv.imrescale( if self.scale_factor != 1:
results['gt_semantic_seg'], results[key] = mmcv.imrescale(
results['scale'], results[key], self.scale_factor, interpolation='nearest')
interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results['gt_semantic_seg'],
results['scale'],
interpolation='nearest')
if results['flip']:
gt_seg = mmcv.imflip(gt_seg)
if gt_seg.shape != results['pad_shape']:
gt_seg = mmcv.impad(gt_seg, results['pad_shape'][:2])
if self.scale_factor != 1:
gt_seg = mmcv.imrescale(
gt_seg, self.scale_factor, interpolation='nearest')
results['gt_semantic_seg'] = gt_seg
return results return results
def __repr__(self): def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment