test_normalize.py 4.47 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest

from mmgen.datasets.pipelines import Normalize


class TestAugmentations(object):
    """Unit tests for the ``Normalize`` data pipeline step."""

    @staticmethod
    def assert_img_equal(img, ref_img, ratio_thr=0.999, atol=1):
        """Check if img and ref_img are matched approximately.

        Args:
            img (np.ndarray): Image under test.
            ref_img (np.ndarray): Reference image. Must have the same shape
                and dtype as ``img``.
            ratio_thr (float): Fraction of elements that has to be exceeded
                by the elements matching within ``atol``. Default: 0.999.
            atol (float): Largest absolute per-element difference that still
                counts as a match. Default: 1.

        Raises:
            AssertionError: If shape or dtype differ, or if the fraction of
                matching elements does not exceed ``ratio_thr``.
        """
        assert img.shape == ref_img.shape
        assert img.dtype == ref_img.dtype
        # Subtract in float64: this avoids wrap-around for unsigned integer
        # images and, unlike an integer cast, does not truncate float images
        # (values in [0, 1) would all collapse to 0 and always "match").
        diff = np.abs(img.astype(np.float64) - ref_img.astype(np.float64))
        # The mean of the boolean mask is the matched fraction over *all*
        # elements, so the ratio is always in [0, 1] whatever the layout.
        assert np.mean(diff <= atol) > ratio_thr

    @staticmethod
    def check_keys_contain(result_keys, target_keys):
        """Check if all elements in target_keys is in result_keys.

        Args:
            result_keys (Iterable): Keys that are actually present.
            target_keys (Iterable): Keys that are required.

        Returns:
            bool: True if every required key is present.
        """
        return set(target_keys).issubset(set(result_keys))

    def check_normalize(self, origin_img, result_img, norm_cfg, atol=1e-3):
        """Check if the origin_img are normalized correctly into result_img in
        a given norm_cfg.

        The normalization is undone on a copy of ``result_img`` and the
        outcome is compared with ``origin_img``.

        Args:
            origin_img (np.ndarray): Image before normalization, with the
                channels on the last axis.
            result_img (np.ndarray): Normalized image. It is not modified.
            norm_cfg (dict): Config with the ``np.ndarray`` entries ``mean``
                and ``std`` and the bool entry ``to_rgb``.
            atol (float): Absolute tolerance passed to ``assert_img_equal``.
                Default: 1e-3.

        Raises:
            AssertionError: If the de-normalized image does not match
                ``origin_img``.
        """
        target_img = result_img.copy()
        # In-place arithmetic keeps the dtype of result_img, which
        # assert_img_equal compares against the dtype of origin_img.
        target_img *= norm_cfg['std'][None, None, :]
        target_img += norm_cfg['mean'][None, None, :]
        if norm_cfg['to_rgb']:
            # A BGR <-> RGB swap reverses the channel (last) axis; reversing
            # axis 1 would mirror the image horizontally instead.
            target_img = target_img[..., ::-1].copy()
        self.assert_img_equal(origin_img, target_img, atol=atol)

    def _check_pipeline(self, config, as_list):
        """Run ``Normalize`` on random data and verify the returned dict.

        Args:
            config (dict): Keyword arguments used to build ``Normalize``.
            as_list (bool): If True, ``results['merged']`` is a list of two
                images, otherwise it is a single image.
        """
        target_keys = ['merged', 'img_norm_cfg']
        num_imgs = 2 if as_list else 1
        images = [
            np.random.rand(240, 320, 3).astype(np.float32)
            for _ in range(num_imgs)
        ]
        results = dict(merged=images if as_list else images[0])

        normalize_results = Normalize(**config)(results)
        assert self.check_keys_contain(normalize_results.keys(), target_keys)

        outputs = normalize_results['merged']
        if not as_list:
            outputs = [outputs]
        assert len(outputs) == num_imgs
        for origin_img, result_img in zip(images, outputs):
            self.check_normalize(origin_img, result_img,
                                 normalize_results['img_norm_cfg'])

    def test_normalize(self):
        """Test argument validation, outputs and ``repr`` of ``Normalize``."""
        mean = [123.675, 116.28, 103.53]
        std = [58.395, 57.12, 57.375]

        # A dict is rejected for ``mean``.
        with pytest.raises(TypeError):
            Normalize(['alpha'], dict(mean=mean), std)

        # A dict is rejected for ``std``.
        with pytest.raises(TypeError):
            Normalize(['alpha'], mean, dict(std=std))

        # Single image first, then a list of images, each with and without
        # the channel swap.
        for as_list in (False, True):
            for to_rgb in (False, True):
                config = dict(
                    keys=['merged'], mean=mean, std=std, to_rgb=to_rgb)
                self._check_pipeline(config, as_list)

        normalize = Normalize(
            keys=['merged'], mean=mean, std=std, to_rgb=True)
        assert repr(normalize) == (
            normalize.__class__.__name__ +
            f"(keys={['merged']}, mean={np.array([123.675, 116.28, 103.53])},"
            f' std={np.array([58.395, 57.12, 57.375])}, to_rgb=True)')