test_persistent_worker.py 908 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

from mmgen.datasets.builder import build_dataloader, build_dataset


class TestPersistentWorker(object):
    """Check dataloader construction with and without persistent workers.

    ``setup_class`` prepares an ``UnconditionalImageDataset`` config pointing
    at the shared test images, plus dataloader kwargs that enable
    ``persistent_workers``.
    """

    @classmethod
    def setup_class(cls):
        """Prepare the dataset config and dataloader kwargs shared by tests."""
        imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        train_pipeline = [
            dict(type='LoadImageFromFile', io_backend='disk', key='real_img')
        ]
        # Keyword arguments forwarded to ``build_dataloader``; the persistent
        # worker mode is the default case under test.
        cls.config = dict(
            samples_per_gpu=1,
            workers_per_gpu=4,
            drop_last=True,
            persistent_workers=True)

        cls.data_cfg = dict(
            type='UnconditionalImageDataset',
            imgs_root=imgs_root,
            pipeline=train_pipeline,
            test_mode=False)

    def test_persistent_worker(self):
        """Build dataloaders in both persistent and non-persistent modes."""
        dataset = build_dataset(self.data_cfg)

        # test persistent-worker (``self.config`` enables it)
        loader = build_dataloader(dataset, **self.config)
        assert loader.dataset is dataset

        # test non-persistent-worker; build a copy of the config so the
        # class-level ``config`` shared with other tests is not mutated.
        non_persistent_config = dict(self.config, persistent_workers=False)
        loader = build_dataloader(dataset, **non_persistent_config)
        assert loader.dataset is dataset