Commit 9118a94a authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Split SegResizeFlipPadRescale into different existing transforms (#1852)

* Split seg trans

* Modify cfg

* fix typo
parent b6966700
......@@ -221,7 +221,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -217,7 +217,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -205,7 +205,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -207,7 +207,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -207,7 +207,7 @@ train_pipeline = [
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
dict(type='SegRescale', scale_factor=1 / 8),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
......
......@@ -98,6 +98,7 @@ class CustomDataset(Dataset):
results['proposal_file'] = self.proposal_file
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
def _filter_imgs(self, min_size=32):
"""Filter images too small."""
......
......@@ -5,12 +5,12 @@ from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
from .test_aug import MultiScaleFlipAug
from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
SegResizeFlipPadRescale)
SegRescale)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop',
'Expand', 'PhotoMetricDistortion', 'Albu'
'RandomCrop', 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'PhotoMetricDistortion', 'Albu'
]
......@@ -91,6 +91,7 @@ class LoadAnnotations(object):
results['gt_semantic_seg'] = mmcv.imread(
osp.join(results['seg_prefix'], results['ann_info']['seg_map']),
flag='unchanged').squeeze()
results['seg_fields'].append('gt_semantic_seg')
return results
def __call__(self, results):
......
......@@ -149,12 +149,23 @@ class Resize(object):
]
results[key] = masks
def _resize_seg(self, results):
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key], results['scale'], interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results[key], results['scale'], interpolation='nearest')
results['gt_semantic_seg'] = gt_seg
def __call__(self, results):
if 'scale' not in results:
self._random_scale(results)
self._resize_img(results)
self._resize_bboxes(results)
self._resize_masks(results)
self._resize_seg(results)
return results
def __repr__(self):
......@@ -229,6 +240,11 @@ class RandomFlip(object):
mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key]
]
# flip segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imflip(
results[key], direction=results['flip_direction'])
return results
def __repr__(self):
......@@ -280,9 +296,14 @@ class Pad(object):
else:
results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)
def _pad_seg(self, results):
for key in results.get('seg_fields', []):
results[key] = mmcv.impad(results[key], results['pad_shape'][:2])
def __call__(self, results):
self._pad_img(results)
self._pad_masks(results)
self._pad_seg(results)
return results
def __repr__(self):
......@@ -386,15 +407,8 @@ class RandomCrop(object):
@PIPELINES.register_module
class SegResizeFlipPadRescale(object):
"""A sequential transforms to semantic segmentation maps.
The same pipeline as input images is applied to the semantic segmentation
map, and finally rescale it by some scale factor. The transforms include:
1. resize
2. flip
3. pad
4. rescale (so that the final size can be different from the image size)
class SegRescale(object):
"""Rescale semantic segmentation maps.
Args:
scale_factor (float): The scale factor of the final output.
......@@ -404,24 +418,10 @@ class SegResizeFlipPadRescale(object):
self.scale_factor = scale_factor
def __call__(self, results):
if results['keep_ratio']:
gt_seg = mmcv.imrescale(
results['gt_semantic_seg'],
results['scale'],
interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results['gt_semantic_seg'],
results['scale'],
interpolation='nearest')
if results['flip']:
gt_seg = mmcv.imflip(gt_seg)
if gt_seg.shape != results['pad_shape']:
gt_seg = mmcv.impad(gt_seg, results['pad_shape'][:2])
if self.scale_factor != 1:
gt_seg = mmcv.imrescale(
gt_seg, self.scale_factor, interpolation='nearest')
results['gt_semantic_seg'] = gt_seg
for key in results.get('seg_fields', []):
if self.scale_factor != 1:
results[key] = mmcv.imrescale(
results[key], self.scale_factor, interpolation='nearest')
return results
def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment