"""Run a trained pix2pix generator on a test dataset and save the images.

The script reads its settings from three sources (later ones take
precedence): the base parameters returned by ``get_base_params``, the
command line, and the ``parameters`` entry of the NNI ``parameter.cfg``
file.  For every test sample it writes the input, the label and the
generated output as PNG files below ``<output_dir>/input``,
``<output_dir>/label`` and ``<output_dir>/output``.
"""
import sys
# Must run before the pix2pixlib imports below.
# NOTE(review): presumably needed so modules inside pix2pixlib can resolve
# their own top-level imports -- confirm against the library layout.
sys.path.insert(0, './pix2pixlib')

import os
import pathlib
import logging
import argparse
import json
import shutil
import subprocess
from collections import namedtuple

from PIL import Image
import numpy as np
import torch
from nni.utils import merge_parameter

from pix2pixlib.data.aligned_dataset import AlignedDataset
from pix2pixlib.data import CustomDatasetDataLoader
from pix2pixlib.models.pix2pix_model import Pix2PixModel
from pix2pixlib.util.util import tensor2im
from base_params import get_base_params

_logger = logging.getLogger('example_pix2pix')

# Dataset archives that download_dataset() knows how to fetch.
_KNOWN_DATASETS = ('facades', 'night2day', 'edges2handbags', 'edges2shoes', 'maps')

# (output sub-directory, key in model.get_current_visuals()) pairs.
_RESULT_VISUALS = (
    ('input', 'real_A'),
    ('label', 'real_B'),
    ('output', 'fake_B'),
)


def download_dataset(dataset_name):
    """Download and unpack a pix2pix dataset into ``./data/<dataset_name>``.

    Nothing is downloaded when ``./data/<dataset_name>`` already exists.
    The archive is fetched with ``wget`` and unpacked with ``tar``; both
    must be available on ``PATH``.

    Args:
        dataset_name: one of ``facades``, ``night2day``, ``edges2handbags``,
            ``edges2shoes`` or ``maps``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
        OSError: if ``wget`` or ``tar`` cannot be started.
        subprocess.CalledProcessError: if ``wget`` or ``tar`` exits with a
            non-zero status.
    """
    # code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    if dataset_name not in _KNOWN_DATASETS:
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(
            dataset_name, ', '.join(_KNOWN_DATASETS)))

    data_root = pathlib.Path('./data')
    target_dir = data_root / dataset_name
    if target_dir.exists():
        _logger.info("Already downloaded dataset %s", dataset_name)
        return

    _logger.info("Downloading dataset %s", dataset_name)
    url = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{}.tar.gz'.format(dataset_name)
    tar_file = data_root / '{}.tar.gz'.format(dataset_name)
    target_dir.mkdir(parents=True, exist_ok=True)
    try:
        # Argument lists (no shell) and check=True so that a failed download
        # or extraction is reported instead of being silently ignored.
        subprocess.run(['wget', '-N', url, '-O', str(tar_file)], check=True)
        subprocess.run(['tar', '-zxvf', str(tar_file), '-C', str(data_root)], check=True)
    except BaseException:
        # The directory did not exist before this call.  Remove it so that a
        # later run does not mistake a failed download for a finished one.
        shutil.rmtree(str(target_dir), ignore_errors=True)
        raise
    finally:
        # The archive is only a temporary artifact, successful or not.
        if tar_file.exists():
            tar_file.unlink()


def parse_args():
    """Parse the command-line options of the test script.

    Unknown options are ignored (``parse_known_args``).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Pix2pix Example')

    # required arguments
    parser.add_argument('-c', '--checkpoint', type=str, required=True,
                        help='Checkpoint directory')
    parser.add_argument('-p', '--parameter_cfg', type=str, required=True,
                        help='parameter.cfg file generated by nni trial')
    parser.add_argument('-d', '--dataset', type=str, required=True,
                        help='dataset name (facades, night2day, edges2handbags, edges2shoes, maps)')
    parser.add_argument('-o', '--output_dir', type=str, required=True,
                        help='Where to save the test results')

    # Settings that may be overridden by parameters from nni
    parser.add_argument('--ngf', type=int, default=64,
                        help='# of generator filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64,
                        help='# of discriminator filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='basic',
                        help='specify discriminator architecture [basic | n_layers | pixel]. '
                             'The basic model is a 70x70 PatchGAN. '
                             'n_layers allows you to specify the layers in the discriminator')
    parser.add_argument('--netG', type=str, default='resnet_9blocks',
                        help='specify generator architecture '
                             '[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
    parser.add_argument('--init_type', type=str, default='normal',
                        help='network initialization [normal | xavier | kaiming | orthogonal]')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='momentum term of adam')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='initial learning rate for adam')
    parser.add_argument('--lr_policy', type=str, default='linear',
                        help='learning rate policy. [linear | step | plateau | cosine]')
    parser.add_argument('--gan_mode', type=str, default='lsgan',
                        help='the type of GAN objective. [vanilla| lsgan | wgangp]. '
                             'vanilla GAN loss is the cross-entropy objective used in the '
                             'original GAN paper.')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization '
                             '[instance | batch | none]')
    parser.add_argument('--lambda_L1', type=float, default=100,
                        help='weight of L1 loss in the generator objective')

    # Additional training settings
    parser.add_argument('--batch_size', type=int, default=1,
                        help='input batch size for training (default: 1)')
    parser.add_argument('--n_epochs', type=int, default=100,
                        help='number of epochs with the initial learning rate')
    parser.add_argument('--n_epochs_decay', type=int, default=100,
                        help='number of epochs to linearly decay learning rate to zero')

    args, _ = parser.parse_known_args()
    return args


def main(test_params):
    """Run the model over the test dataset and save the resulting images.

    Args:
        test_params: mapping of option name to value.  This function reads
            ``checkpoint``, ``dataset``, ``output_dir`` and ``eval``; the
            whole mapping is also handed (as a namedtuple) to the dataset,
            data loader and model constructors.

    Raises:
        FileNotFoundError: if the ``checkpoint`` path does not exist.
    """
    # The pix2pixlib classes are given an attribute-style options object.
    test_config = namedtuple('Struct', test_params.keys())(*test_params.values())
    if not os.path.exists(test_config.checkpoint):
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise FileNotFoundError("Checkpoint does not exist")

    # Create the result folders here so main() also works when it is called
    # from another module rather than through the __main__ guard.
    for subdir, _ in _RESULT_VISUALS:
        pathlib.Path(test_config.output_dir, subdir).mkdir(parents=True, exist_ok=True)

    download_dataset(test_config.dataset)

    test_dataset = AlignedDataset(test_config)
    test_dataset = CustomDatasetDataLoader(test_config, test_dataset)
    _logger.info('Number of testing images = {}'.format(len(test_dataset)))

    model = Pix2PixModel(test_config)
    model.setup(test_config)
    if test_config.eval:
        model.eval()

    for i, data in enumerate(test_dataset):
        # '\r' keeps the progress report on a single console line.
        print('Testing on {} image {}'.format(test_config.dataset, i), end='\r')
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()

        image_name = '{}_test_{}.png'.format(test_config.dataset, i)
        for subdir, visual_key in _RESULT_VISUALS:
            image = Image.fromarray(tensor2im(visuals[visual_key]))
            image.save(os.path.join(test_config.output_dir, subdir, image_name))

    _logger.info("Images successfully saved to " + test_config.output_dir)


if __name__ == '__main__':
    params_from_cl = vars(parse_args())

    # Start from the base test parameters, then apply the command line.
    _, test_params = get_base_params(params_from_cl['dataset'], params_from_cl['checkpoint'])
    test_params.update(params_from_cl)

    # The first line of parameter.cfg is parsed as a JSON object whose
    # 'parameters' entry is merged on top of everything else.
    with open(test_params['parameter_cfg'], 'r') as f:
        params_from_nni = json.loads(f.readline().strip())['parameters']
    test_params = merge_parameter(test_params, params_from_nni)

    main(test_params)