Unverified Commit 994a2d23 authored by Bill Wu, committed by GitHub

[Example] Conditional GAN in pytorch (#3687)

parent 68644f59
Pix2pix example
=================
Overview
--------
`Pix2pix <https://arxiv.org/abs/1611.07004>`__ is a conditional generative adversarial network (conditional GAN) framework proposed by Isola et al. in 2016 for solving image-to-image translation problems. This framework performs well on a wide range of image generation problems. In the original paper, the authors demonstrate how to use pix2pix to solve the following image translation problems: 1) labels to street scene; 2) labels to facade; 3) BW to color; 4) aerial to map; 5) day to night; and 6) edges to photo. If you are interested, please read more on the `official project page <https://phillipi.github.io/pix2pix/>`__. In this example, we use pix2pix to introduce how to use NNI for tuning conditional GANs.
**Goals**
^^^^^^^^^^^^^
Although GANs are known to be able to generate high-resolution realistic images, they are generally fragile and difficult to optimize: mode collapse can happen during training due to improper optimization settings, loss formulation, model architecture, weight initialization, or even data augmentation patterns. The goal of this tutorial is to leverage NNI's hyperparameter tuning tools to automatically find a good setting for these important factors.
In this example, we aim at selecting the following hyperparameters automatically:
* ``ngf``: number of generator filters in the last conv layer
* ``ndf``: number of discriminator filters in the first conv layer
* ``netG``: generator architecture
* ``netD``: discriminator architecture
* ``norm``: normalization type
* ``init_type``: weight initialization method
* ``lr``: initial learning rate for adam
* ``beta1``: momentum term of adam
* ``lr_policy``: learning rate policy
* ``gan_mode``: type of GAN objective
* ``lambda_L1``: weight of L1 loss in the generator objective
**Experiments**
^^^^^^^^^^^^^^^^^^^^
Preparations
^^^^^^^^^^^^
This example requires the GPU version of PyTorch. The right PyTorch build depends on your system, Python version, and CUDA version.
Please refer to the detailed instructions for installing `PyTorch <https://pytorch.org/get-started/locally/>`__.
Next, run the following shell script to clone the repository maintained by the original authors of pix2pix. This example relies on the implementations in this repository.
.. code-block:: bash
./setup.sh
Pix2pix with NNI
^^^^^^^^^^^^^^^^^
**Search Space**
We summarize the range of values for each hyperparameter mentioned above into a single search space JSON object.
.. code-block:: json
{
"ngf": {"_type":"choice","_value":[16, 32, 64, 128, 256]},
"ndf": {"_type":"choice","_value":[16, 32, 64, 128, 256]},
"netG": {"_type":"choice","_value":["resnet_9blocks", "unet_256"]},
"netD": {"_type":"choice","_value":["basic", "pixel", "n_layers"]},
"norm": {"_type":"choice","_value":["batch", "instance", "none"]},
"init_type": {"_type":"choice","_value":["xavier", "normal", "kaiming", "orthogonal"]},
"lr":{"_type":"choice","_value":[0.0001, 0.0002, 0.0005, 0.001, 0.005, 0.01, 0.1]},
"beta1":{"_type":"uniform","_value":[0, 1]},
"lr_policy": {"_type":"choice","_value":["linear", "step", "plateau", "cosine"]},
"gan_mode": {"_type":"choice","_value":["vanilla", "lsgan", "wgangp"]} ,
"lambda_L1": {"_type":"choice","_value":[1, 5, 10, 100, 250, 500]}
}
Starting from v2.0, the search space is directly included in the config. Please find the example here: :githublink:`config.yml <examples/trials/pix2pix-pytorch/config.yml>`
**Trial**
To experiment on this set of hyperparameters with NNI, we have to write trial code, which receives a set of parameter settings from NNI, trains a generator and a discriminator using these parameters, and then reports the final score back to NNI. In the experiment, NNI repeatedly calls this trial code, passing in a different set of hyperparameter settings each time. It is important that the following three lines are incorporated in the trial code (a minimal skeleton is sketched after this list):
* Use ``nni.get_next_parameter()`` to get the next hyperparameter set.
* (Optional) Use ``nni.report_intermediate_result(score)`` to report the intermediate result after finishing each epoch.
* Use ``nni.report_final_result(score)`` to report the final result before the trial ends.
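Put together, a minimal trial skeleton looks like the sketch below. The ``train_one_epoch`` and ``evaluate`` helpers are hypothetical stand-ins for the actual pix2pix training and evaluation code:

.. code-block:: python

   import nni

   # Hypothetical stand-ins for the actual training and evaluation code.
   def train_one_epoch(params):
       pass

   def evaluate(params):
       return 0.5  # placeholder for the test-set L1 loss

   params = nni.get_next_parameter()           # 1. receive a hyperparameter set from NNI
   score = None
   for epoch in range(200):
       train_one_epoch(params)
       score = evaluate(params)
       nni.report_intermediate_result(score)   # 2. (optional) report per-epoch score
   nni.report_final_result(score)              # 3. report the final score before the trial ends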
The implemented trial code: :githublink:`pix2pix.py <examples/trials/pix2pix-pytorch/pix2pix.py>`
Some notes on the implementation:
* The trial code for this example is adapted from the `repository maintained by the authors of Pix2pix and CycleGAN <https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix>`__ . You can also use your previous code directly. Please refer to `How to define a trial <Trials.rst>`__ for modifying the code.
* By default, the code uses the dataset "facades". It also supports the datasets "night2day", "edges2handbags", "edges2shoes", and "maps".
* For "facades", 200 epochs are enough for the model to converge to a point where the difference between models trained with different hyperparameters are salient enough for evaluation. If you are using other datasets, please consider increasing the ``n_epochs`` and ``n_epochs_decay`` parameters by either passing them as arguments when calling ``pix2pix.py`` in the config file (discussed below) or changing the ``pix2pix.py`` directly. Also, for "facades", 200 epochs are enought for the final training, while the number may vary for other datasets.
* In this example, we use the L1 loss on the test set as the score to report to NNI. Although L1 is by no means a comprehensive measure of image generation performance, in most cases it makes sense for evaluating pix2pix models with a similar architectural setup. For the hyperparameters we experiment on, a lower L1 score generally indicates better generation performance. A condensed sketch of this evaluation follows this list.
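For reference, the score is simply the mean absolute pixel difference between the generated image and the ground truth, averaged over the test set. The sketch below condenses ``evaluate_L1()`` from ``pix2pix.py``:

.. code-block:: python

   import torch

   def l1_score(fake_B, real_B):
       # mean absolute pixel difference between the generated image (fake_B)
       # and the ground-truth image (real_B)
       return torch.mean(torch.abs(fake_B - real_B)).item()

   # toy usage with random tensors standing in for real test images
   scores = [l1_score(torch.rand(3, 256, 256), torch.rand(3, 256, 256)) for _ in range(4)]
   print(sum(scores) / len(scores))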
**Config**
Here is the example config for running this experiment locally (with a single GPU):
Code: :githublink:`examples/trials/pix2pix-pytorch/config.yml <examples/trials/pix2pix-pytorch/config.yml>`
For a full view of our implementation, check: :githublink:`examples/trials/pix2pix-pytorch/ <examples/trials/pix2pix-pytorch>`
Launch the experiment
^^^^^^^^^^^^^^^^^^^^^
We are ready for the experiment. Now, **run the config.yml file from your command line to start it**:
.. code-block:: bash
nnictl create --config nni/examples/trials/pix2pix-pytorch/config.yml
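Once the experiment starts, ``nnictl`` prints the Web UI URL, where the progress of each trial can be monitored. When the experiment is done, it can be stopped from the command line, for example:

.. code-block:: bash

   # list running experiments, then stop the experiment
   nnictl experiment list
   nnictl stop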
Collecting the Results
^^^^^^^^^^^^^^^^^^^^^^
By default, our trial code saves the final trained model for each trial in the ``checkpoints/`` directory inside the trial directory of the NNI experiment. ``latest_net_G.pth`` and ``latest_net_D.pth`` correspond to the saved checkpoints of the generator and the discriminator, respectively.
To make it easier to run inference and see the generated images, we also include a simple inference script here: :githublink:`test.py <examples/trials/pix2pix-pytorch/test.py>`
To use the code, run the following command:
.. code-block:: bash
python3 test.py -c CHECKPOINT -p PARAMETER_CFG -d DATASET_NAME -o OUTPUT_DIR
``CHECKPOINT`` is the directory containing the saved checkpoints (e.g., the ``checkpoints/`` directory in the trial directory). ``PARAMETER_CFG`` is the ``parameter.cfg`` file generated by NNI that records the hyperparameter settings; this file can be found in the trial directory created by NNI. ``DATASET_NAME`` is the dataset to run inference on, and ``OUTPUT_DIR`` is where the generated images are saved.
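For reference, ``parameter.cfg`` is a JSON record; a minimal sketch of loading the tuned hyperparameters from it (mirroring how ``test.py`` reads the file) looks like this:

.. code-block:: python

   import json

   # The first line of parameter.cfg is a JSON record; the hyperparameters
   # NNI assigned to the trial live under its 'parameters' key.
   with open('parameter.cfg', 'r') as f:
       record = json.loads(f.readline().strip())
   tuned_params = record['parameters']
   print(tuned_params)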
Results and Discussions
^^^^^^^^^^^^^^^^^^^^^^^
Following the previous steps, we ran the example for 40 trials using the TPE tuner. We found the best-performing parameters on the "facades" dataset to be the following set:
.. code-block:: json
{
"ngf": 16,
"ndf": 128,
"netG": "unet_256",
"netD": "pixel",
"norm": "none",
"init_type": "normal",
"lr": 0.0002,
"beta1": 0.6954,
"lr_policy": "step",
"gan_mode": "lsgan",
"lambda_L1": 500
}
Meanwhile, we compare the results with a model trained using the following default empirical hyperparameter settings:
.. code-block:: json
{
"ngf": 128,
"ndf": 128,
"netG": "unet_256",
"netD": "basic",
"norm": "batch",
"init_type": "xavier",
"lr": 0.0002,
"beta1": 0.5,
"lr_policy": "linear",
"gan_mode": "lsgan",
"lambda_L1": 100
}
We can observe that for the learning rate (0.0002), the generator architecture (U-Net), and the GAN objective (LSGAN), the two results agree with each other. This is also consistent with the widely accepted practice on this dataset. Meanwhile, the hyperparameters "beta1", "lambda_L1", "ngf", and "ndf" are adjusted in the solution found by NNI to fit the target dataset. We found that the parameters searched by NNI outperform the empirical parameters on the facades dataset, both in terms of L1 loss and the visual quality of the images: while the searched hyperparameters achieve an L1 loss of 0.3317 on the facades test set, the empirical hyperparameters only achieve an L1 loss of 0.4148. The following image shows some sample input-output pairs from the facades test set produced by the model with hyperparameters tuned by NNI.
.. image:: ../../img/pix2pix_pytorch_facades.png
:target: ../../img/pix2pix_pytorch_facades.png
:alt:
@@ -8,4 +8,5 @@ Examples
MNIST<./TrialExample/MnistExamples>
Cifar10<./TrialExample/Cifar10Examples>
Scikit-learn<./TrialExample/SklearnExamples>
GBDT<./TrialExample/GbdtExample>
Pix2pix<./TrialExample/Pix2pixExample>
# datasets
data/
# pix2pix library
pix2pixlib/
def get_base_params(dataset_name, checkpoint_dir):
params = {}
# change name and gpuid later
basic_params = {'dataset': dataset_name,
'dataroot': './data/' + dataset_name,
'name': '',
'gpu_ids': [0],
'checkpoints_dir': checkpoint_dir,
'verbose': False,
'print_freq': 100
}
params.update(basic_params)
dataset_params = {'dataset_mode': 'aligned',
'direction': 'BtoA',
'num_threads': 4,
'max_dataset_size': float('inf'),
'preprocess': 'resize_and_crop',
'display_winsize': 256,
'input_nc': 3,
'output_nc': 3}
params.update(dataset_params)
model_params = {'model': 'pix2pix',
# 'ngf': 64,
# 'ndf': 64,
# 'netD': 'basic',
# 'netG': 'unet_256',
'n_layers_D': 3,
# 'norm': 'batch',
# 'gan_mode': 'lsgan',
# 'init_type': 'normal',
'init_gain': 0.02,
'no_dropout': False}
params.update(model_params)
train_params = {'phase': 'train',
'isTrain': True,
'serial_batches': False,
'load_size': 286,
'crop_size': 256,
'no_flip': False,
# 'batch_size': 1,
# 'beta1': 0.5,
'pool_size': 0,
# 'lr_policy': 'linear',
'lr_decay_iters': 50,
#'lr': 0.0002,
# 'lambda_L1': 100,
'epoch_count': 1,
# 'n_epochs': 10, # 100
# 'n_epochs_decay': 0, # 100
'continue_train': False}
train_params.update(params)
test_params = {'phase': 'test',
'isTrain': False,
'load_iter': -1,
'epoch': 'latest',
'load_size': 256,
'crop_size': 256,
# 'batch_size': 1,
'serial_batches': True,
'no_flip': True,
'eval': True}
test_params.update(params)
return train_params, test_params
experimentName: example_pix2pix
searchSpace:
ngf:
_type: choice
_value: [16, 32, 64, 128]
ndf:
_type: choice
_value: [16, 32, 64, 128]
netG:
_type: choice
_value: ["unet_256", "resnet_9blocks"]
netD:
_type: choice
_value: ["basic", "pixel", "n_layers"]
norm:
_type: choice
_value: ["batch", "instance", "none"]
init_type:
_type: choice
_value: ["xavier", "normal", "kaiming", "orthogonal"]
lr:
_type: choice
_value: [0.0001, 0.0002, 0.0005, 0.001, 0.005, 0.01, 0.1]
beta1:
_type: uniform
_value: [0, 1]
lr_policy:
_type: choice
_value: ["linear", "step", "plateau", "cosine"]
gan_mode:
_type: choice
_value: ["vanilla", "lsgan", "wgangp"]
lambda_L1:
_type: choice
_value: [1, 5, 10, 100, 250, 500]
trainingService:
platform: local
useActiveGpu: true
gpuIndices: '0'
trialCodeDirectory: .
trialCommand: python3 pix2pix.py
trialConcurrency: 1
trialGpuNumber: 1
tuner:
name: TPE
classArgs:
optimize_mode: minimize
import sys
sys.path.insert(0, './pix2pixlib')
import os
import pathlib
import logging
import time
import argparse
from collections import namedtuple
import numpy as np
import torch
import torch.utils.data as data
import nni
from nni.utils import merge_parameter
from pix2pixlib.data.aligned_dataset import AlignedDataset
from pix2pixlib.models.pix2pix_model import Pix2PixModel
from base_params import get_base_params
_logger = logging.getLogger('example_pix2pix')
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
    def __init__(self, opt, ds):
        """Initialize with an options object ``opt`` and a dataset instance ``ds``,
        and create a multi-threaded data loader over the dataset.
        """
self.opt = opt
self.dataset = ds
self.dataloader = data.DataLoader(self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads))
def load_data(self):
return self
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data
def download_dataset(dataset_name):
# code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
assert(dataset_name in ['facades', 'night2day', 'edges2handbags', 'edges2shoes', 'maps'])
if os.path.exists('./data/' + dataset_name):
_logger.info("Already downloaded dataset " + dataset_name)
else:
_logger.info("Downloading dataset " + dataset_name)
if not os.path.exists('./data/'):
pathlib.Path('./data/').mkdir(parents=True, exist_ok=True)
pathlib.Path('./data/' + dataset_name).mkdir(parents=True, exist_ok=True)
URL = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{}.tar.gz'.format(dataset_name)
TAR_FILE = './data/{}.tar.gz'.format(dataset_name)
TARGET_DIR = './data/{}/'.format(dataset_name)
os.system('wget -N {} -O {}'.format(URL, TAR_FILE))
pathlib.Path(TARGET_DIR).mkdir(parents=True, exist_ok=True)
os.system('tar -zxvf {} -C ./data/'.format(TAR_FILE))
os.system('rm ' + TAR_FILE)
def setup_trial_checkpoint_dir():
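    # NNI sets the NNI_OUTPUT_DIR environment variable to each trial's output
    # directory, so checkpoints saved here end up inside the trial directory
    # of the experiment.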
checkpoint_dir = os.environ['NNI_OUTPUT_DIR'] + '/checkpoints/'
pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
return checkpoint_dir
def parse_args():
    # Settings that may be overridden by parameters from NNI
parser = argparse.ArgumentParser(description='PyTorch Pix2pix Example')
parser.add_argument('--ngf', type=int, default=64,
help='# of generator filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64,
help='# of discriminator filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic',
help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks',
help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--init_type', type=str, default='normal',
help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--beta1', type=float, default=0.5,
help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002,
help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='linear',
help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--gan_mode', type=str, default='lsgan',
help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--norm', type=str, default='instance',
help='instance normalization or batch normalization [instance | batch | none]')
parser.add_argument('--lambda_L1', type=float, default=100,
help='weight of L1 loss in the generator objective')
# Additional training settings
parser.add_argument('--batch_size', type=int, default=1,
help='input batch size for training (default: 1)')
parser.add_argument('--n_epochs', type=int, default=100,
help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100,
help='number of epochs to linearly decay learning rate to zero')
args, _ = parser.parse_known_args()
return args
def evaluate_L1(config, model, dataset):
if config.eval:
model.eval()
scores = []
for i, data in enumerate(dataset):
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals()
score = torch.mean(torch.abs(visuals['fake_B']-visuals['real_B'])).detach().cpu().numpy()
scores.append(score)
return np.mean(np.array(scores))
def main(dataset_name, train_params, test_params):
download_dataset(dataset_name)
train_config = namedtuple('Struct', train_params.keys())(*train_params.values())
test_config = namedtuple('Struct', test_params.keys())(*test_params.values())
train_dataset, test_dataset = AlignedDataset(train_config), AlignedDataset(test_config)
train_dataset = CustomDatasetDataLoader(train_config, train_dataset)
test_dataset = CustomDatasetDataLoader(test_config, test_dataset)
_logger.info('Number of training images = {}'.format(len(train_dataset)))
_logger.info('Number of testing images = {}'.format(len(test_dataset)))
model = Pix2PixModel(train_config)
model.setup(train_config)
# training
total_iters = 0 # the total number of training iterations
for epoch in range(train_config.epoch_count, train_config.n_epochs + train_config.n_epochs_decay + 1):
_logger.info('Training epoch {}'.format(epoch))
epoch_start_time = time.time() # timer for entire epoch
iter_data_time = time.time() # timer for data loading per iteration
epoch_iter = 0
model.update_learning_rate()
for i, data in enumerate(train_dataset): # inner loop within one epoch
iter_start_time = time.time() # timer for computation per iteration
if total_iters % train_config.print_freq == 0:
t_data = iter_start_time - iter_data_time
total_iters += train_config.batch_size
epoch_iter += train_config.batch_size
model.set_input(data) # unpack data from dataset and apply preprocessing
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
iter_data_time = time.time()
_logger.info('End of epoch {} / {} \t Time Taken: {} sec'.format(epoch, train_config.n_epochs + train_config.n_epochs_decay, time.time() - epoch_start_time))
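        # (optional) an intermediate result could be reported to NNI here at
        # the end of each epoch, e.g.:
        # nni.report_intermediate_result(evaluate_L1(test_config, model, test_dataset))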
model.save_networks('latest')
_logger.info("Training done. Saving the final model.")
l1_score = evaluate_L1(test_config, model, test_dataset)
_logger.info("The final L1 loss the test set is {}".format(l1_score))
nni.report_final_result(l1_score)
if __name__ == '__main__':
dataset_name = 'facades'
checkpoint_dir = setup_trial_checkpoint_dir()
params_from_cl = vars(parse_args())
params_for_tuning = nni.get_next_parameter()
train_params, test_params = get_base_params(dataset_name, checkpoint_dir)
train_params.update(params_from_cl)
test_params.update(params_from_cl)
train_params = merge_parameter(train_params, params_for_tuning)
main(dataset_name, train_params, test_params)
#!/bin/bash
# download pix2pix repository
if [ ! -d './pix2pixlib' ] ; then
git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git pix2pixlib
fi
import sys
sys.path.insert(0, './pix2pixlib')
import os
import pathlib
import logging
import argparse
import json
from collections import namedtuple
from PIL import Image
import numpy as np
import torch
from nni.utils import merge_parameter
from pix2pixlib.data.aligned_dataset import AlignedDataset
from pix2pixlib.data import CustomDatasetDataLoader
from pix2pixlib.models.pix2pix_model import Pix2PixModel
from pix2pixlib.util.util import tensor2im
from base_params import get_base_params
_logger = logging.getLogger('example_pix2pix')
def download_dataset(dataset_name):
# code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
assert(dataset_name in ['facades', 'night2day', 'edges2handbags', 'edges2shoes', 'maps'])
if os.path.exists('./data/' + dataset_name):
_logger.info("Already downloaded dataset " + dataset_name)
else:
_logger.info("Downloading dataset " + dataset_name)
if not os.path.exists('./data/'):
pathlib.Path('./data/').mkdir(parents=True, exist_ok=True)
pathlib.Path('./data/' + dataset_name).mkdir(parents=True, exist_ok=True)
URL = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{}.tar.gz'.format(dataset_name)
TAR_FILE = './data/{}.tar.gz'.format(dataset_name)
TARGET_DIR = './data/{}/'.format(dataset_name)
os.system('wget -N {} -O {}'.format(URL, TAR_FILE))
pathlib.Path(TARGET_DIR).mkdir(parents=True, exist_ok=True)
os.system('tar -zxvf {} -C ./data/'.format(TAR_FILE))
os.system('rm ' + TAR_FILE)
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch Pix2pix Example')
# required arguments
parser.add_argument('-c', '--checkpoint', type=str, required=True,
help='Checkpoint directory')
parser.add_argument('-p', '--parameter_cfg', type=str, required=True,
help='parameter.cfg file generated by nni trial')
parser.add_argument('-d', '--dataset', type=str, required=True,
help='dataset name (facades, night2day, edges2handbags, edges2shoes, maps)')
parser.add_argument('-o', '--output_dir', type=str, required=True,
help='Where to save the test results')
    # Settings that may be overridden by parameters from NNI
parser.add_argument('--ngf', type=int, default=64,
help='# of generator filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64,
help='# of discriminator filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic',
help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks',
help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--init_type', type=str, default='normal',
help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--beta1', type=float, default=0.5,
help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002,
help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='linear',
help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--gan_mode', type=str, default='lsgan',
help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--norm', type=str, default='instance',
help='instance normalization or batch normalization [instance | batch | none]')
parser.add_argument('--lambda_L1', type=float, default=100,
help='weight of L1 loss in the generator objective')
# Additional training settings
parser.add_argument('--batch_size', type=int, default=1,
help='input batch size for training (default: 1)')
parser.add_argument('--n_epochs', type=int, default=100,
help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100,
help='number of epochs to linearly decay learning rate to zero')
args, _ = parser.parse_known_args()
return args
def main(test_params):
test_config = namedtuple('Struct', test_params.keys())(*test_params.values())
assert os.path.exists(test_config.checkpoint), "Checkpoint does not exist"
download_dataset(test_config.dataset)
test_dataset = AlignedDataset(test_config)
test_dataset = CustomDatasetDataLoader(test_config, test_dataset)
_logger.info('Number of testing images = {}'.format(len(test_dataset)))
model = Pix2PixModel(test_config)
model.setup(test_config)
if test_config.eval:
model.eval()
for i, data in enumerate(test_dataset):
print('Testing on {} image {}'.format(test_config.dataset, i), end='\r')
model.set_input(data)
model.test()
visuals = model.get_current_visuals()
cur_input = tensor2im(visuals['real_A'])
cur_label = tensor2im(visuals['real_B'])
cur_output = tensor2im(visuals['fake_B'])
image_name = '{}_test_{}.png'.format(test_config.dataset, i)
Image.fromarray(cur_input).save(os.path.join(test_config.output_dir, 'input', image_name))
Image.fromarray(cur_label).save(os.path.join(test_config.output_dir, 'label', image_name))
Image.fromarray(cur_output).save(os.path.join(test_config.output_dir, 'output', image_name))
_logger.info("Images successfully saved to " + test_config.output_dir)
if __name__ == '__main__':
params_from_cl = vars(parse_args())
_, test_params = get_base_params(params_from_cl['dataset'], params_from_cl['checkpoint'])
test_params.update(params_from_cl)
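    # parameter.cfg is written by NNI for each trial; its first line is a JSON
    # record whose 'parameters' field holds the hyperparameters assigned to
    # the trial.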
with open(test_params['parameter_cfg'], 'r') as f:
params_from_nni = json.loads(f.readline().strip())['parameters']
test_params = merge_parameter(test_params, params_from_nni)
pathlib.Path(params_from_cl['output_dir'] + '/input').mkdir(parents=True, exist_ok=True)
pathlib.Path(params_from_cl['output_dir'] + '/label').mkdir(parents=True, exist_ok=True)
pathlib.Path(params_from_cl['output_dir'] + '/output').mkdir(parents=True, exist_ok=True)
main(test_params)