test_grow_scale_img_dataset.py 3.52 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import pytest

from mmgen.datasets import GrowScaleImgDataset


class TestGrowScaleImgDataset:
    """Tests for ``GrowScaleImgDataset`` with per-scale image roots.

    ``setup_class`` maps three scale keys to fixture directories under
    ``../data/image`` (relative to this file):

    * ``'4'``  -> ``data/image``
    * ``'8'``  -> ``data/image/img_root``
    * ``'32'`` -> ``data/image/img_root/grass``
    """

    @classmethod
    def setup_class(cls):
        cls.imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        cls.imgs_roots = {
            '4': cls.imgs_root,
            '8': osp.join(cls.imgs_root, 'img_root'),
            '32': osp.join(cls.imgs_root, 'img_root', 'grass')
        }
        cls.default_pipeline = [
            dict(type='LoadImageFromFile', io_backend='disk', key='real_img')
        ]
        cls.len_per_stage = 10
        cls.gpu_samples_base = 2

    @staticmethod
    def _check_stage(dataset, expected_len, expected_root):
        """Assert length, one loaded sample and ``repr`` of ``dataset``.

        Args:
            dataset (GrowScaleImgDataset): Dataset under test, already set
                to the scale being checked.
            expected_len (int): Expected ``len(dataset)``.
            expected_root (str): Image root expected to appear in ``repr``.
        """
        assert len(dataset) == expected_len
        img = dataset[2]['real_img']
        assert img.ndim == 3
        assert repr(dataset) == (
            f'dataset_name: {dataset.__class__}, '
            f'total {expected_len} images in imgs_root: {expected_root}')

    def test_dynamic_unconditional_img_dataset(self):
        root_4 = self.imgs_root
        root_8 = osp.join(self.imgs_root, 'img_root')
        root_32 = osp.join(self.imgs_root, 'img_root', 'grass')

        # Without per-scale overrides, samples_per_gpu equals
        # gpu_samples_base at every scale.
        dataset = GrowScaleImgDataset(
            self.imgs_roots,
            self.default_pipeline,
            self.len_per_stage,
            gpu_samples_base=self.gpu_samples_base)
        self._check_stage(dataset, self.len_per_stage, root_4)
        assert dataset.samples_per_gpu == 2

        dataset.update_annotations(8)
        self._check_stage(dataset, self.len_per_stage, root_8)
        assert dataset.samples_per_gpu == 2

        # Scale '4' has an override (10); scale '8' has none, so it is
        # expected to fall back to gpu_samples_base.
        dataset = GrowScaleImgDataset(
            self.imgs_roots,
            self.default_pipeline,
            20,
            gpu_samples_base=self.gpu_samples_base,
            gpu_samples_per_scale={
                '4': 10,
                '16': 13
            })
        self._check_stage(dataset, 20, root_4)
        assert dataset.samples_per_gpu == 10

        dataset.update_annotations(8)
        self._check_stage(dataset, 20, root_8)
        assert dataset.samples_per_gpu == 2

        dataset = GrowScaleImgDataset(
            self.imgs_roots, self.default_pipeline, 5, test_mode=True)
        self._check_stage(dataset, 5, root_4)

        # 24 is not one of the configured scale keys; the dataset is
        # expected to switch to the root registered under '32'.
        dataset.update_annotations(24)
        self._check_stage(dataset, 5, root_32)

        # A non-dict gpu_samples_per_scale should be rejected. A valid
        # imgs_roots dict is passed so that this is the only invalid
        # argument and the check is exercised in isolation.
        with pytest.raises(AssertionError):
            _ = GrowScaleImgDataset(
                self.imgs_roots,
                self.default_pipeline,
                10,
                gpu_samples_per_scale=10)

        # A non-dict imgs_roots should be rejected.
        with pytest.raises(AssertionError):
            _ = GrowScaleImgDataset(10, self.default_pipeline, 10.)