Commit 2619aa9c authored by Yifei Yang, committed by zhouzaida

[Feature] Add Part3 of data transform (#1735)

* update data transform part3

* update init

* rename flip funcs

* fix comments

* update comments

* fix lint

* Update mmcv/transforms/processing.py

* fix docs format

* fix comments

* add test pad_val and fix bugs in class Pad

* merge updated pad

* fix lint

* Update tests/test_transforms/test_transforms_processing.py
parent 5af6c12b

mmcv/transforms/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile
-from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize
+from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
+                         RandomFlip, RandomGrayscale, RandomMultiscaleResize,
+                         RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
try:
@@ -10,7 +12,8 @@ except ImportError:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
-    'RandomFlip', 'RandomResize'
+    'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
+    'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
@@ -18,5 +21,7 @@ else:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
-    'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip', 'RandomResize'
+    'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip',
+    'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale',
+    'MultiScaleFlipAug', 'RandomResize'
]
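
For context on how the newly exported transforms are used, here is a minimal sketch of building one through the TRANSFORMS registry, mirroring the pattern in the tests below; it assumes torch is installed and that 'demo.jpg' is a placeholder path for a color image at least 224 pixels on each side:

    import mmcv
    from mmcv.transforms import TRANSFORMS

    # Build the new CenterCrop transform from a config dict via the registry.
    center_crop = TRANSFORMS.build(dict(type='CenterCrop', crop_size=224))

    # Apply it to a results dict carrying the image; the transform crops
    # results['img'] to the central 224x224 patch and records the new size.
    results = dict(img=mmcv.imread('demo.jpg', 'color'))
    results = center_crop(results)
    print(results['height'], results['width'])  # 224 224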
mmcv/transforms/formatting.py
@@ -16,9 +16,11 @@ def to_tensor(
    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data.
    """
......
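
As a quick illustration of the documented contract above (a sketch, assuming torch is installed so that mmcv.transforms re-exports to_tensor):

    import numpy as np
    from mmcv.transforms import to_tensor

    # Each supported input type converts to a torch.Tensor.
    t_array = to_tensor(np.zeros((3, 2)))  # from numpy.ndarray
    t_seq = to_tensor([1, 2, 3])           # from Sequence
    t_int = to_tensor(1)                   # from int
    t_float = to_tensor(1.5)               # from float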
mmcv/transforms/processing.py
This diff is collapsed.
tests/test_transforms/test_transforms_processing.py
@@ -6,7 +6,18 @@ import numpy as np
import pytest
import mmcv
-from mmcv.transforms import Normalize, Pad, RandomFlip, RandomResize, Resize
+from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
+                             RandomResize, Resize)
+# numpy.testing does not depend on torch; import it unconditionally so the
+# assert_array_* helpers exist even when torch is missing.
+from numpy.testing import assert_array_almost_equal, assert_array_equal
+
+try:
+    import torch
+except ModuleNotFoundError:
+    torch = None
+else:
+    import torchvision
+    from PIL import Image

class TestNormalize:
@@ -223,6 +234,416 @@ class TestPad:
"pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")

class TestCenterCrop:
@classmethod
def setup_class(cls):
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(img)
seg = np.random.randint(0, 19, (300, 400)).astype(np.uint8)
cls.gt_semantic_map = copy.deepcopy(seg)
@staticmethod
def reset_results(results, original_img, gt_semantic_map):
results['img'] = copy.deepcopy(original_img)
results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
return results
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_error(self):
# test assertion if size is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=-1)
TRANSFORMS.build(transform)
# test assertion if size is tuple but one value is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, -1))
TRANSFORMS.build(transform)
# test assertion if size is tuple and len(size) < 2
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, ))
TRANSFORMS.build(transform)
# test assertion if size is tuple len(size) > 2
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, 224, 3))
TRANSFORMS.build(transform)
def test_repr(self):
# test repr
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
assert isinstance(repr(center_crop_module), str)
def test_transform(self):
results = {}
self.reset_results(results, self.original_img, self.gt_semantic_map)
# test CenterCrop when size is int
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
# test CenterCrop when size is tuple
transform = dict(type='CenterCrop', crop_size=(224, 224))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
# test CenterCrop when crop_height != crop_width
transform = dict(type='CenterCrop', crop_size=(256, 224))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 256
assert results['width'] == 224
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
88:312]).all()
# test CenterCrop when crop_size is equal to img.shape
img_height, img_width, _ = self.original_img.shape
transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
# test CenterCrop when crop_size is larger than img.shape
transform = dict(
type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
# test with padding
transform = dict(
type='CenterCrop',
crop_size=(img_height * 2, img_width // 2),
pad_mode='constant',
pad_val=12)
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
transform = dict(
type='CenterCrop',
crop_size=(img_height * 2, img_width // 2),
pad_mode='constant',
pad_val=dict(img=13, seg=33))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
# test CenterCrop when crop_width is smaller than img_width
transform = dict(
type='CenterCrop', crop_size=(img_height, img_width // 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height
assert results['width'] == img_width // 2
assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[:,
100:300]).all()
# test CenterCrop when crop_height is smaller than img_height
transform = dict(
type='CenterCrop', crop_size=(img_height // 2, img_width))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height // 2
assert results['width'] == img_width
assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
...]).all()
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_torchvision_compare(self):
# compare results with torchvision
results = {}
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
center_crop_module = torchvision.transforms.CenterCrop(size=224)
pil_img = Image.fromarray(self.original_img)
pil_seg = Image.fromarray(self.gt_semantic_map)
cropped_img = center_crop_module(pil_img)
cropped_img = np.array(cropped_img)
cropped_seg = center_crop_module(pil_seg)
cropped_seg = np.array(cropped_seg)
assert np.equal(results['img'], cropped_img).all()
assert np.equal(results['gt_semantic_seg'], cropped_seg).all()

class TestRandomGrayscale:
@classmethod
def setup_class(cls):
cls.img = np.random.rand(10, 10, 3).astype(np.float32)
def test_repr(self):
# test repr
transform = dict(
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
random_gray_scale_module = TRANSFORMS.build(transform)
assert isinstance(repr(random_gray_scale_module), str)
def test_error(self):
# test invalid argument
transform = dict(type='RandomGrayscale', prob=2)
with pytest.raises(AssertionError):
TRANSFORMS.build(transform)
def test_transform(self):
results = dict()
# test rgb2gray, return the grayscale image with prob = 1.
transform = dict(
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img']
computed_gray = (
self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 +
self.img[:, :, 2] * 0.114)
for i in range(img.shape[2]):
assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
assert img.shape == (10, 10, 3)
# test rgb2gray, return the original image with p=0.
transform = dict(type='RandomGrayscale', prob=0.)
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img']
assert_array_equal(img, self.img)
assert img.shape == (10, 10, 3)
# test image with one channel
transform = dict(type='RandomGrayscale', prob=1.)
results['img'] = self.img[:, :, 0:1]
random_gray_scale_module = TRANSFORMS.build(transform)
img = random_gray_scale_module(results)['img']
assert_array_equal(img, self.img[:, :, 0:1])
assert img.shape == (10, 10, 1)

class TestMultiScaleFlipAug:
@classmethod
def setup_class(cls):
cls.img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(cls.img)
def test_error(self):
# test assertion if img_scale is None
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=None, transforms=[])
TRANSFORMS.build(transform)
# test assertion if img_scale is not tuple or list of tuple
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=[1333, 800], transforms=[])
TRANSFORMS.build(transform)
# test assertion if flip_direction is not str or list of str
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug',
img_scale=[(1333, 800)],
flip_direction=1,
transforms=[])
TRANSFORMS.build(transform)
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_multi_scale_flip_aug(self):
# test with empty transforms
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12
# test with flip=False
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 3
# test with transforms
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12

class TestRandomMultiscaleResize:
@classmethod
def setup_class(cls):
cls.img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(cls.img)
def reset_results(self, results):
results['img'] = copy.deepcopy(self.original_img)
results['gt_semantic_seg'] = copy.deepcopy(self.original_img)
def test_repr(self):
# test repr
transform = dict(
type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)])
random_multiscale_resize = TRANSFORMS.build(transform)
assert isinstance(repr(random_multiscale_resize), str)
def test_error(self):
# test assertion if size is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='RandomMultiscaleResize', scales=[0.5, 1, 2])
TRANSFORMS.build(transform)
def test_random_multiscale_resize(self):
results = dict()
# test with one scale
transform = dict(type='RandomMultiscaleResize', scales=[(1333, 800)])
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
assert results['img'].shape == (800, 1333, 3)
# test with multi scales
_scale_choice = [(1333, 800), (1333, 600)]
transform = dict(type='RandomMultiscaleResize', scales=_scale_choice)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
assert (results['img'].shape[1],
results['img'].shape[0]) in _scale_choice
# test keep_ratio
transform = dict(
type='RandomMultiscaleResize',
scales=[(900, 600)],
keep_ratio=True)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
_input_ratio = results['img'].shape[0] / results['img'].shape[1]
results = random_multiscale_resize(results)
_output_ratio = results['img'].shape[0] / results['img'].shape[1]
assert_array_almost_equal(_input_ratio, _output_ratio)
# test clip_object_border
gt_bboxes = [[200, 150, 600, 450]]
transform = dict(
type='RandomMultiscaleResize',
scales=[(200, 150)],
clip_object_border=True)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
results = random_multiscale_resize(results)
assert results['img'].shape == (150, 200, 3)
assert np.equal(results['gt_bboxes'], np.array([[100, 75, 200,
150]])).all()
transform = dict(
type='RandomMultiscaleResize',
scales=[(200, 150)],
clip_object_border=False)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
results = random_multiscale_resize(results)
assert results['img'].shape == (150, 200, 3)
assert np.equal(results['gt_bboxes'], np.array([[100, 75, 300,
225]])).all()

class TestRandomFlip:
def test_init(self):
......