test_ffhq_dataset.py 1.27 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import math
import os
import torch
import torchvision.utils

from basicsr.data import build_dataloader, build_dataset


def main():
    """Smoke-test ``FFHQDataset`` by dumping a few batches to disk.

    Builds an ``FFHQDataset`` over the LMDB at ``datasets/ffhq/ffhq_256.lmdb``
    (path relative to the current working directory). It wraps the dataset in a
    non-distributed dataloader. For the first six batches it prints:

    * the batch index,
    * the min/max of the ``gt`` tensor,
    * the ``gt_path`` entries.

    It then saves each ``gt`` batch as an image grid to ``tmp/gt_XXX.png``.
    """
    opt = {}
    opt['dist'] = False
    opt['gpu_ids'] = [0]
    opt['phase'] = 'train'

    opt['name'] = 'FFHQ'
    opt['type'] = 'FFHQDataset'

    opt['dataroot_gt'] = 'datasets/ffhq/ffhq_256.lmdb'
    opt['io_backend'] = dict(type='lmdb')

    opt['use_hflip'] = True
    # NOTE(review): mean/std of 0.5 presumably map [0, 1] pixels to [-1, 1]
    # via (x - 0.5) / 0.5; the value_range passed to save_image below relies
    # on that — confirm against FFHQDataset.
    opt['mean'] = [0.5, 0.5, 0.5]
    opt['std'] = [0.5, 0.5, 0.5]

    opt['num_worker_per_gpu'] = 1
    opt['batch_size_per_gpu'] = 4

    opt['dataset_enlarge_ratio'] = 1

    os.makedirs('tmp', exist_ok=True)

    dataset = build_dataset(opt)
    # Non-distributed loader (dist=False) with no explicit sampler.
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)

    # Square grid layout: a batch of 4 images is saved as 2 per row.
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(data_loader):
        if i > 5:  # only inspect the first six batches (indices 0-5)
            break
        print(i)

        gt = data['gt']
        print(torch.min(gt), torch.max(gt))
        gt_path = data['gt_path']
        print(gt_path)
        # `value_range` replaces make_grid's deprecated `range` keyword, which
        # newer torchvision releases reject with a TypeError.
        torchvision.utils.save_image(
            gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=True, value_range=(-1, 1))


if __name__ == '__main__':
    main()