Commit 3bfb5c80 authored by yongshk

Initial commit
import os
import torch
import sys
class BaseModel(torch.nn.Module):
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda()
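# Illustrative usage from a subclass (a minimal sketch; 'latest' is just an
# example epoch label): saving self.netG under label 'G' writes
# <checkpoints_dir>/<name>/latest_net_G.pth and then moves the net back to GPU.
#
# self.save_network(self.netG, 'G', 'latest', self.gpu_ids)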
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=''):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
print('%s does not exist yet!' % save_path)
if network_label == 'G':
raise Exception('Generator must exist!')
else:
#network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path))
except:
pretrained_dict = torch.load(save_path)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
if self.opt.verbose:
print('Pretrained network %s has extra layers; only loading the layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
if sys.version_info >= (3,0):
not_initialized = set()
else:
from sets import Set
not_initialized = Set()
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split('.')[0])
print(sorted(not_initialized))
network.load_state_dict(model_dict)
def update_learning_rate(self):
pass
import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
if opt.isTrain:
model = Pix2PixHDModel()
else:
model = InferenceModel()
else:
from .ui_model import UIModel
model = UIModel()
model.initialize(opt)
if opt.verbose:
print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
return model
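# Note: when DataParallel wraps the model here, its attributes live under
# .module (e.g. model.module.netE), which is how the feature precomputation
# script later in this commit reaches the encoder.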
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
###############################################################################
# Functions
###############################################################################
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
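# Example (illustrative): the returned value is a factory that is later called
# with a channel count, not a layer instance:
#
# norm_layer = get_norm_layer('instance')
# layer = norm_layer(64) # equivalent to nn.InstanceNorm2d(64, affine=False)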
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
n_blocks_local=3, norm='instance', gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
if netG == 'global':
netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
elif netG == 'local':
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
n_local_enhancers, n_blocks_local, norm_layer)
elif netG == 'encoder':
netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
else:
raise NotImplementedError('generator not implemented!')
print(netG)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netG.cuda(gpu_ids[0])
netG.apply(weights_init)
return netG
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
print(netD)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netD.cuda(gpu_ids[0])
netD.apply(weights_init)
return netD
def print_network(net):
if isinstance(net, list):
net = net[0]
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
if isinstance(input[0], list):
loss = 0
for input_i in input:
pred = input_i[-1]
target_tensor = self.get_target_tensor(pred, target_is_real)
loss += self.loss(pred, target_tensor)
return loss
else:
target_tensor = self.get_target_tensor(input[-1], target_is_real)
return self.loss(input[-1], target_tensor)
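# Illustrative usage (a sketch; netD is assumed to be a MultiscaleDiscriminator,
# which returns one list per scale, so the isinstance branch above sums the
# loss over all scales):
#
# criterion = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
# pred_fake = netD(torch.cat((input_label, fake_image), dim=1))
# loss_G_GAN = criterion(pred_fake, True) # generator wants fakes scored as real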
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19().cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
##############################################################################
# Generator
##############################################################################
class LocalEnhancer(nn.Module):
def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
super(LocalEnhancer, self).__init__()
self.n_local_enhancers = n_local_enhancers
###### global generator model #####
ngf_global = ngf * (2**n_local_enhancers)
model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
self.model = nn.Sequential(*model_global)
###### local enhancer layers #####
for n in range(1, n_local_enhancers+1):
### downsample
ngf_global = ngf * (2**(n_local_enhancers-n))
model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
norm_layer(ngf_global), nn.ReLU(True),
nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf_global * 2), nn.ReLU(True)]
### residual blocks
model_upsample = []
for i in range(n_blocks_local):
model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
### upsample
model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(ngf_global), nn.ReLU(True)]
### final convolution
if n == n_local_enhancers:
model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, input):
### create input pyramid
input_downsampled = [input]
for i in range(self.n_local_enhancers):
input_downsampled.append(self.downsample(input_downsampled[-1]))
### output at coarsest level
output_prev = self.model(input_downsampled[-1])
### build up one layer at a time
for n_local_enhancers in range(1, self.n_local_enhancers+1):
model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
output_prev = model_upsample(model_downsample(input_i) + output_prev)
return output_prev
class GlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
assert(n_blocks >= 0)
super(GlobalGenerator, self).__init__()
activation = nn.ReLU(True)
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
### resnet blocks
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
### upsample
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(int(ngf * mult / 2)), activation]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
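# Shape sketch (illustrative numbers): with ngf=64 and n_downsampling=3, the
# input is reduced to 1/8 resolution at 512 channels, run through the residual
# blocks, then upsampled back to the input size:
#
# netG = GlobalGenerator(input_nc=36, output_nc=3, ngf=64, n_downsampling=3, n_blocks=9)
# out = netG(torch.randn(1, 36, 512, 1024)) # -> (1, 3, 512, 1024)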
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class Encoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
super(Encoder, self).__init__()
self.output_nc = output_nc
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf), nn.ReLU(True)]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), nn.ReLU(True)]
### upsample
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input, inst):
outputs = self.model(input)
# instance-wise average pooling
outputs_mean = outputs.clone()
inst_list = np.unique(inst.cpu().numpy().astype(int))
for i in inst_list:
for b in range(input.size()[0]):
indices = (inst[b:b+1] == int(i)).nonzero() # n x 4
for j in range(self.output_nc):
output_ins = outputs[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]]
mean_feat = torch.mean(output_ins).expand_as(output_ins)
outputs_mean[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
return outputs_mean
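# Note: the loops above replace every feature value inside an instance mask
# with its mean over that instance, so the encoder output is constant per
# object; these per-object features are what later get clustered and sampled
# as "styles" at test time.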
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3, getIntermFeat=False):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
for i in range(num_D):
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
if getIntermFeat:
for j in range(n_layers+2):
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
else:
setattr(self, 'layer'+str(i), netD.model)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def singleD_forward(self, model, input):
if self.getIntermFeat:
result = [input]
for i in range(len(model)):
result.append(model[i](result[-1]))
return result[1:]
else:
return [model(input)]
def forward(self, input):
num_D = self.num_D
result = []
input_downsampled = input
for i in range(num_D):
if self.getIntermFeat:
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
else:
model = getattr(self, 'layer'+str(num_D-1-i))
result.append(self.singleD_forward(model, input_downsampled))
if i != (num_D-1):
input_downsampled = self.downsample(input_downsampled)
return result
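# Output structure (as implemented above): forward returns a list with num_D
# entries, the full-resolution discriminator first; each entry is either
# [final_score] or, with getIntermFeat, the list of intermediate feature maps
# ending in the score.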
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[1:]
else:
return self.model(input)
from torchvision import models
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
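# Note: with torchvision's vgg19 layer indexing, the five slices end at the
# relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 activations; VGGLoss above
# compares these five feature levels with weights increasing from 1/32 to 1.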
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
return loss_filter
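# Example (illustrative): with use_gan_feat_loss=False and use_vgg_loss=True,
# loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake) returns
# [g_gan, g_vgg, d_real, d_fake], so disabled losses never show up in the
# loss_names list defined below.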
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
self.use_features = opt.instance_feat or opt.label_feat
self.gen_features = self.use_features and not self.opt.load_features
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
##### define networks
# Generator network
netG_input_nc = input_nc
if not opt.no_instance:
netG_input_nc += 1
if self.use_features:
netG_input_nc += opt.feat_num
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
netD_input_nc = input_nc + opt.output_nc
if not opt.no_instance:
netD_input_nc += 1
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
### Encoder network
if self.gen_features:
self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
if self.opt.verbose:
print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
pretrained_path = '' if not self.isTrain else opt.load_pretrain
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
if self.isTrain:
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
if self.gen_features:
self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionFeat = torch.nn.L1Loss()
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss(self.gpu_ids)
# Names so we can break out the individual losses
self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')
# initialize optimizers
# optimizer G
if opt.niter_fix_global > 0:
import sys
if sys.version_info >= (3,0):
finetune_list = set()
else:
from sets import Set
finetune_list = Set()
params_dict = dict(self.netG.named_parameters())
params = []
for key, value in params_dict.items():
if key.startswith('model' + str(opt.n_local_enhancers)):
params += [value]
finetune_list.add(key.split('.')[0])
print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
print('The layers that are finetuned are ', sorted(finetune_list))
else:
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
if self.opt.label_nc == 0:
input_label = label_map.data.cuda()
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
# instance map for feature encoding
if self.use_features:
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
inst_map = label_map.cuda()
return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
return self.netD.forward(input_concat)
def forward(self, label, inst, image, feat, infer=False):
# Encode Inputs
input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
# Fake Generation
if self.use_features:
if not self.opt.load_features:
feat_map = self.netE.forward(real_image, inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
fake_image = self.netG.forward(input_concat)
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i])-1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
if not self.opt.no_vgg_loss:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
# Only return the fake_B image if necessary to save BW
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
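# Note: forward returns [filtered_loss_list, fake_image_or_None]; the image is
# only gathered when infer=True, which avoids moving the full-resolution fake
# across devices on ordinary training steps.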
def inference(self, label, inst, image=None):
# Encode Inputs
image = Variable(image) if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
# Fake Generation
if self.use_features:
if self.opt.use_encoded_image:
# encode the real image to get feature map
feat_map = self.netE.forward(real_image, inst_map)
else:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
fake_image = self.netG.forward(input_concat)
return fake_image
def sample_features(self, inst):
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
features_clustered = np.load(cluster_path, encoding='latin1').item()
# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
for i in np.unique(inst_np):
label = i if i < 1000 else i//1000
if label in features_clustered:
feat = features_clustered[label]
cluster_idx = np.random.randint(0, feat.shape[0])
idx = (inst == int(i)).nonzero()
for k in range(self.opt.feat_num):
feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
if self.opt.data_type==16:
feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
image = Variable(image.cuda(), volatile=True)
feat_num = self.opt.feat_num
h, w = inst.size()[2], inst.size()[3]
block_num = 32
feat_map = self.netE.forward(image, inst.cuda())
inst_np = inst.cpu().numpy().astype(int)
feature = {}
for i in range(self.opt.label_nc):
feature[i] = np.zeros((0, feat_num+1))
for i in np.unique(inst_np):
label = i if i < 1000 else i//1000
idx = (inst == int(i)).nonzero()
num = idx.size()[0]
idx = idx[num//2,:]
val = np.zeros((1, feat_num+1))
for k in range(feat_num):
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
val[0, feat_num] = float(num) / (h * w // block_num)
feature[label] = np.append(feature[label], val, axis=0)
return feature
def get_edges(self, t):
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
if self.opt.data_type==16:
return edge.half()
else:
return edge.float()
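# Example (illustrative): for a horizontal run of instance ids [5, 5, 7, 7],
# the pixels on both sides of the 5/7 transition are set to 1, so every object
# ends up outlined by a two-pixel-wide boundary that is concatenated to the
# label input.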
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
# after fixing the global generator for a number of iterations, also start finetuning it
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
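# Worked example (with the default options): lr=0.0002 and niter_decay=100
# give lrd = 2e-6 per epoch, so after the first niter epochs the learning rate
# decays linearly to zero over the remaining 100 epochs.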
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label, inst = inp
return self.inference(label, inst)
import torch
from torch.autograd import Variable
from collections import OrderedDict
import numpy as np
import os
from PIL import Image
import util.util as util
from .base_model import BaseModel
from . import networks
class UIModel(BaseModel):
def name(self):
return 'UIModel'
def initialize(self, opt):
assert(not opt.isTrain)
BaseModel.initialize(self, opt)
self.use_features = opt.instance_feat or opt.label_feat
netG_input_nc = opt.label_nc
if not opt.no_instance:
netG_input_nc += 1
if self.use_features:
netG_input_nc += opt.feat_num
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
self.load_network(self.netG, 'G', opt.which_epoch)
print('---------- Networks initialized -------------')
def toTensor(self, img, normalize=False):
tensor = torch.from_numpy(np.array(img, np.int32, copy=False))
tensor = tensor.view(1, img.size[1], img.size[0], len(img.mode))
tensor = tensor.transpose(1, 2).transpose(1, 3).contiguous()
if normalize:
return (tensor.float()/255.0 - 0.5) / 0.5
return tensor.float()
def load_image(self, label_path, inst_path, feat_path):
opt = self.opt
# read label map
label_img = Image.open(label_path)
if label_path.find('face') != -1:
label_img = label_img.convert('L')
ow, oh = label_img.size
w = opt.loadSize
h = int(w * oh / ow)
label_img = label_img.resize((w, h), Image.NEAREST)
label_map = self.toTensor(label_img)
# onehot vector input for label map
self.label_map = label_map.cuda()
oneHot_size = (1, opt.label_nc, h, w)
input_label = self.Tensor(torch.Size(oneHot_size)).zero_()
self.input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0)
# read instance map
if not opt.no_instance:
inst_img = Image.open(inst_path)
inst_img = inst_img.resize((w, h), Image.NEAREST)
self.inst_map = self.toTensor(inst_img).cuda()
self.edge_map = self.get_edges(self.inst_map)
self.net_input = Variable(torch.cat((self.input_label, self.edge_map), dim=1), volatile=True)
else:
self.net_input = Variable(self.input_label, volatile=True)
self.features_clustered = np.load(feat_path).item()
self.object_map = self.inst_map if opt.instance_feat else self.label_map
object_np = self.object_map.cpu().numpy().astype(int)
self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_()
self.cluster_indices = np.zeros(self.opt.label_nc, np.uint8)
for i in np.unique(object_np):
label = i if i < 1000 else i//1000
if label in self.features_clustered:
feat = self.features_clustered[label]
np.random.seed(i+1)
cluster_idx = np.random.randint(0, feat.shape[0])
self.cluster_indices[label] = cluster_idx
idx = (self.object_map == i).nonzero()
self.set_features(idx, feat, cluster_idx)
self.net_input_original = self.net_input.clone()
self.label_map_original = self.label_map.clone()
self.feat_map_original = self.feat_map.clone()
if not opt.no_instance:
self.inst_map_original = self.inst_map.clone()
def reset(self):
self.net_input = self.net_input_prev = self.net_input_original.clone()
self.label_map = self.label_map_prev = self.label_map_original.clone()
self.feat_map = self.feat_map_prev = self.feat_map_original.clone()
if not self.opt.no_instance:
self.inst_map = self.inst_map_prev = self.inst_map_original.clone()
self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
def undo(self):
self.net_input = self.net_input_prev
self.label_map = self.label_map_prev
self.feat_map = self.feat_map_prev
if not self.opt.no_instance:
self.inst_map = self.inst_map_prev
self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
# get boundary map from instance map
def get_edges(self, t):
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
return edge.float()
# change the label at the source position to the label at the target position
def change_labels(self, click_src, click_tgt):
y_src, x_src = click_src[0], click_src[1]
y_tgt, x_tgt = click_tgt[0], click_tgt[1]
label_src = int(self.label_map[0, 0, y_src, x_src])
inst_src = self.inst_map[0, 0, y_src, x_src]
label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt]
idx_src = (self.inst_map == inst_src).nonzero()
# need to change 3 things: label map, instance map, and feature map
if idx_src.shape:
# backup current maps
self.backup_current_state()
# change both the label map and the network input
self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0
self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
# update the instance map (and the network input)
if inst_tgt > 1000:
# if different instances have different ids, give the new object a new id
tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000)
inst_tgt = self.inst_map[tgt_indices].max() + 1
self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt
self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
# also copy the source features to the target position
idx_tgt = (self.inst_map == inst_tgt).nonzero()
if idx_tgt.shape:
self.copy_features(idx_src, idx_tgt[0,:])
self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
# add strokes of target label in the image
def add_strokes(self, click_src, label_tgt, bw, save):
# get the region of the new strokes (bw is the brush width)
size = self.net_input.size()
h, w = size[2], size[3]
idx_src = torch.LongTensor(bw**2, 4).fill_(0)
for i in range(bw):
idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i))
for j in range(bw):
idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j))
idx_src = idx_src.cuda()
# again, need to update 3 things
if idx_src.shape:
# backup current maps
if save:
self.backup_current_state()
# update the label map (and the network input) in the stroke region
self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
for k in range(self.opt.label_nc):
self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
# update the instance map (and the network input)
self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
# also update the features if available
if self.opt.instance_feat:
feat = self.features_clustered[label_tgt]
#np.random.seed(label_tgt+1)
#cluster_idx = np.random.randint(0, feat.shape[0])
cluster_idx = self.cluster_indices[label_tgt]
self.set_features(idx_src, feat, cluster_idx)
self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
# add an object to the clicked position with selected style
def add_objects(self, click_src, label_tgt, mask, style_id=0):
y, x = click_src[0], click_src[1]
mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...]
idx_src = torch.from_numpy(mask).cuda().nonzero()
idx_src[:,2] += y
idx_src[:,3] += x
# backup current maps
self.backup_current_state()
# update label map
self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
for k in range(self.opt.label_nc):
self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
# update instance map
self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
# update feature map
self.set_features(idx_src, self.feat, style_id)
self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
def single_forward(self, net_input, feat_map):
net_input = torch.cat((net_input, feat_map), dim=1)
fake_image = self.netG.forward(net_input)
if fake_image.size()[0] == 1:
return fake_image.data[0]
return fake_image.data
# generate all outputs for different styles
def style_forward(self, click_pt, style_id=-1):
if click_pt is None:
self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
self.crop = None
self.mask = None
else:
instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
self.instToChange = instToChange
label = instToChange if instToChange < 1000 else instToChange//1000
self.feat = self.features_clustered[label]
self.fake_image = []
self.mask = self.object_map == instToChange
idx = self.mask.nonzero()
self.get_crop_region(idx)
if idx.size():
if style_id == -1:
(min_y, min_x, max_y, max_x) = self.crop
### original
for cluster_idx in range(self.opt.multiple_output):
self.set_features(idx, self.feat, cluster_idx)
fake_image = self.single_forward(self.net_input, self.feat_map)
fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x])
self.fake_image.append(fake_image)
"""### To speed up previewing different style results, either crop or downsample the label maps
if instToChange > 1000:
(min_y, min_x, max_y, max_x) = self.crop
### crop
_, _, h, w = self.net_input.size()
offset = 512
y_start, x_start = max(0, min_y-offset), max(0, min_x-offset)
y_end, x_end = min(h, (max_y + offset)), min(w, (max_x + offset))
y_region = slice(y_start, y_start+(y_end-y_start)//16*16)
x_region = slice(x_start, x_start+(x_end-x_start)//16*16)
net_input = self.net_input[:,:,y_region,x_region]
for cluster_idx in range(self.opt.multiple_output):
self.set_features(idx, self.feat, cluster_idx)
fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region])
fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start])
self.fake_image.append(fake_image)
else:
### downsample
(min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop]
net_input = self.net_input[:,:,::2,::2]
size = net_input.size()
net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3])
for cluster_idx in range(self.opt.multiple_output):
self.set_features(idx, self.feat, cluster_idx)
feat_map = self.feat_map[:,:,::2,::2]
if cluster_idx == 0:
feat_map_batch = feat_map
else:
feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0)
fake_image_batch = self.single_forward(net_input_batch, feat_map_batch)
for i in range(self.opt.multiple_output):
self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))"""
else:
self.set_features(idx, self.feat, style_id)
self.cluster_indices[label] = style_id
self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
def backup_current_state(self):
self.net_input_prev = self.net_input.clone()
self.label_map_prev = self.label_map.clone()
self.inst_map_prev = self.inst_map.clone()
self.feat_map_prev = self.feat_map.clone()
# crop the ROI and get the mask of the object
def get_crop_region(self, idx):
size = self.net_input.size()
h, w = size[2], size[3]
min_y, min_x = idx[:,2].min(), idx[:,3].min()
max_y, max_x = idx[:,2].max(), idx[:,3].max()
crop_min = 128
if max_y - min_y < crop_min:
min_y = max(0, (max_y + min_y) // 2 - crop_min // 2)
max_y = min(h-1, min_y + crop_min)
if max_x - min_x < crop_min:
min_x = max(0, (max_x + min_x) // 2 - crop_min // 2)
max_x = min(w-1, min_x + crop_min)
self.crop = (min_y, min_x, max_y, max_x)
self.mask = self.mask[:,:, min_y:max_y, min_x:max_x]
# update the feature map once a new object is added or the label is changed
def update_features(self, cluster_idx, mask=None, click_pt=None):
self.feat_map_prev = self.feat_map.clone()
# adding a new object
if mask is not None:
y, x = click_pt[0], click_pt[1]
mask = np.transpose(mask, (2,0,1))[np.newaxis,...]
idx = torch.from_numpy(mask).cuda().nonzero()
idx[:,2] += y
idx[:,3] += x
# changing the label of an existing object
else:
idx = (self.object_map == self.instToChange).nonzero()
# update feature map
self.set_features(idx, self.feat, cluster_idx)
# set the class features to the target feature
def set_features(self, idx, feat, cluster_idx):
for k in range(self.opt.feat_num):
self.feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
# copy the features at the target position to the source position
def copy_features(self, idx_src, idx_tgt):
for k in range(self.opt.feat_num):
val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
self.feat_map[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = val
def get_current_visuals(self, getLabel=False):
mask = self.mask
if self.mask is not None:
mask = np.transpose(self.mask[0].cpu().float().numpy(), (1,2,0)).astype(np.uint8)
dict_list = [('fake_image', self.fake_image), ('mask', mask)]
if getLabel: # only output label map if needed to save bandwidth
label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
dict_list += [('label', label)]
return OrderedDict(dict_list)
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
# experiment specifics
self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="supported data type, i.e. 8, 16, or 32 bit")
self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
# for setting inputs
self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
# for displays
self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
# for generator
self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which we only train the outermost local enhancer')
# for instance-wise features
self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
self.initialized = True
def parse(self, save=True):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
if save and not self.opt.continue_train:
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
self.isTrain = False
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
# for displays
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
# for training
self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
# for discriminators
self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
self.parser.add_argument('--no_lsgan', action='store_true', help='if specified, do *not* use least squares GAN; use the vanilla GAN loss instead')
self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
self.isTrain = True
<img src='imgs/teaser_720.gif' align="right" width=360>
<br><br><br><br>
# pix2pixHD
### [Project](https://tcwang0509.github.io/pix2pixHD/) | [Youtube](https://youtu.be/3AIpPlzM_qs) | [Paper](https://arxiv.org/pdf/1711.11585.pdf) <br>
PyTorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used for turning semantic label maps into photorealistic images or synthesizing portraits from face label maps. <br><br>
[High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
[Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In CVPR 2018.
## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
<p align='center'>
<img src='imgs/teaser_label.png' width='400'/>
<img src='imgs/teaser_ours.jpg' width='400'/>
</p>
- Interactive editing results
<p align='center'>
<img src='imgs/teaser_style.gif' width='400'/>
<img src='imgs/teaser_label.gif' width='400'/>
</p>
- Additional streetview results
<p align='center'>
<img src='imgs/cityscapes_1.jpg' width='400'/>
<img src='imgs/cityscapes_2.jpg' width='400'/>
</p>
<p align='center'>
<img src='imgs/cityscapes_3.jpg' width='400'/>
<img src='imgs/cityscapes_4.jpg' width='400'/>
</p>
- Label-to-face and interactive editing results
<p align='center'>
<img src='imgs/face1_1.jpg' width='250'/>
<img src='imgs/face1_2.jpg' width='250'/>
<img src='imgs/face1_3.jpg' width='250'/>
</p>
<p align='center'>
<img src='imgs/face2_1.jpg' width='250'/>
<img src='imgs/face2_2.jpg' width='250'/>
<img src='imgs/face2_3.jpg' width='250'/>
</p>
- Our editing interface
<p align='center'>
<img src='imgs/city_short.gif' width='330'/>
<img src='imgs/face_short.gif' width='450'/>
</p>
## Prerequisites
- Linux or macOS
- Python 2 or 3
- NVIDIA GPU (11G memory or larger) + CUDA and cuDNN
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install the Python library [dominate](https://github.com/Knio/dominate):
```bash
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/NVIDIA/pix2pixHD
cd pix2pixHD
```
### Testing
- A few example Cityscapes test images are included in the `datasets` folder.
- Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1h9SykUnuZul7J3Nbms2QGH1wa85nbN2-/view?usp=sharing) (google drive link), and put it under `./checkpoints/label2city_1024p/`
- Test the model (`bash ./scripts/test_1024p.sh`):
```bash
#!./scripts/test_1024p.sh
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
```
The test results will be saved to an HTML file: `./results/label2city_1024p/test_latest/index.html`.
More example scripts can be found in the `scripts` directory.
### Dataset
- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
After downloading, please put it under the `datasets` folder in the same way the example images are provided.
### Training
- Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
```bash
#!./scripts/train_512p.sh
python train.py --name label2city_512p
```
- To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
If you have TensorFlow installed, you can see TensorBoard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts.
### Multi-GPU training
- Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
```bash
#!./scripts/train_512p_multigpu.sh
python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
```
Note: multi-GPU training is not tested; we trained our model using a single GPU only. Please use it at your own discretion.
### Training with Automatic Mixed Precision (AMP) for faster speed
- To train with mixed precision support, please first install apex from: https://github.com/NVIDIA/apex
- You can then train the model by adding `--fp16`. For example,
```bash
#!./scripts/train_512p_fp16.sh
python -m torch.distributed.launch train.py --name label2city_512p --fp16
```
In our test case, it trains about 80% faster with AMP on a Volta machine.
### Training at full resolution
- Training at full resolution (2048 x 1024) requires a GPU with 24G of memory (`bash ./scripts/train_1024p_24G.sh`), or 16G of memory if using mixed precision (AMP).
- If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed using this script.
### Training with your own dataset
- If you want to train with your own dataset, please generate one-channel label maps whose pixel values correspond to the object labels (i.e. 0, 1, ..., N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps; a sketch of that conversion follows this list. Please also specify `--label_nc N` during both training and testing.
- If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input. The folders should then be named `train_A`, `train_B` instead of `train_label`, `train_img`, where the goal is to translate images from A to B.
- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
- The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32.
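For reference, a minimal sketch (with illustrative tensor names and sizes) of the one-hot conversion the model performs internally on such label maps:

```python
import torch

label_nc = 35  # number of label classes (N)
label_map = torch.randint(0, label_nc, (1, 1, 512, 1024))  # one-channel label map

# place a 1 along the channel dimension at each pixel's label index
one_hot = torch.zeros(1, label_nc, 512, 1024)
one_hot.scatter_(1, label_map.long(), 1.0)
print(one_hot.shape)  # torch.Size([1, 35, 512, 1024])
```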
## More Training/Test Details
- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
- Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`. (A sketch of deriving a boundary map from an instance map follows this list.)
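Because instance IDs themselves are arbitrary, the useful signal in an instance map is where instances meet; a boundary map can be derived as in the sketch below (an illustration of the idea, which may differ from this repository's own code).
```python
import torch

def instance_to_edge_map(inst_map):
    # inst_map: tensor of shape (N, 1, H, W) holding integer instance IDs.
    # A pixel is an edge if any horizontal or vertical neighbor differs.
    edge = torch.zeros_like(inst_map, dtype=torch.bool)
    edge[:, :, :, 1:]  |= inst_map[:, :, :, 1:] != inst_map[:, :, :, :-1]
    edge[:, :, :, :-1] |= inst_map[:, :, :, 1:] != inst_map[:, :, :, :-1]
    edge[:, :, 1:, :]  |= inst_map[:, :, 1:, :] != inst_map[:, :, :-1, :]
    edge[:, :, :-1, :] |= inst_map[:, :, 1:, :] != inst_map[:, :, :-1, :]
    return edge.float()
```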
## Citation
If you find this useful for your research, please cite the following.
```
@inproceedings{wang2018pix2pixHD,
title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2018}
}
```
## Acknowledgments
This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
import os
import torch.nn as nn
from torch.autograd import Variable
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util

# Force deterministic, single-image batches so every training image is
# visited exactly once while its feature map is precomputed.
opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))

######## Save precomputed feature maps for 1024p training #######
for i, data in enumerate(dataset):
    print('%d / %d images' % (i + 1, dataset_size))
    # Encode each image (conditioned on its instance map) into a feature map.
    feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
    # Upsample so the stored map matches the resolution used for 1024p training.
    feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
    image_numpy = util.tensor2im(feat_map.data[0])
    save_path = data['path'][0].replace('/train_label/', '/train_feat/')
    util.save_image(image_numpy, save_path)
import os
import sys
from random import randint
import numpy as np
try:
from PIL import Image
import pycuda.driver as cuda
import pycuda.gpuarray as gpuarray
import pycuda.autoinit
import argparse
except ImportError as err:
sys.stderr.write("""ERROR: failed to import module ({})
Please make sure you have pycuda and the example dependencies installed.
https://wiki.tiker.net/PyCuda/Installation/Linux
pip(3) install tensorrt[examples]
""".format(err))
exit(1)
try:
import tensorrt as trt
from tensorrt.parsers import caffeparser
from tensorrt.parsers import onnxparser
except ImportError as err:
sys.stderr.write("""ERROR: failed to import module ({})
Please make sure you have the TensorRT Library installed
and accessible in your LD_LIBRARY_PATH
""".format(err))
exit(1)
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
class Profiler(trt.infer.Profiler):
"""
Example Implimentation of a Profiler
Is identical to the Profiler class in trt.infer so it is possible
to just use that instead of implementing this if further
functionality is not needed
"""
def __init__(self, timing_iter):
trt.infer.Profiler.__init__(self)
self.timing_iterations = timing_iter
self.profile = []
    def report_layer_time(self, layerName, ms):
        # Accumulate each layer's time across all timed iterations.
        record = next((r for r in self.profile if r[0] == layerName), (None, None))
        if record == (None, None):
            self.profile.append((layerName, ms))
        else:
            self.profile[self.profile.index(record)] = (record[0], record[1] + ms)
def print_layer_times(self):
totalTime = 0
for i in range(len(self.profile)):
print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations))
totalTime += self.profile[i][1]
print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations))
def get_input_output_names(trt_engine):
    nbindings = trt_engine.get_nb_bindings()
    maps = []
    for b in range(0, nbindings):
        dims = trt_engine.get_binding_dimensions(b).to_DimsCHW()
        name = trt_engine.get_binding_name(b)
        dtype = trt_engine.get_binding_data_type(b)
        maps.append(name)
        if trt_engine.binding_is_input(b):
            print("Found input: ", name)
        else:
            print("Found output: ", name)
        print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
        print("dtype=" + str(dtype))
    return maps
def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx):
    binding_idx = engine.get_binding_index(name)
    if binding_idx == -1:
        raise AttributeError("Not a valid binding")
    print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx)))
    dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW()
    eltCount = dims.C() * dims.H() * dims.W() * batchsize
    if engine.binding_is_input(binding_idx):
        # Inputs come from the caller-provided host buffers.
        h_mem = inp[inp_idx]
        inp_idx = inp_idx + 1
    else:
        # Outputs get a placeholder host buffer; the device overwrites it.
        h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4'))
    d_mem = cuda.mem_alloc(eltCount * 4)
    cuda.memcpy_htod(d_mem, h_mem)
    buf.insert(binding_idx, int(d_mem))
    mem.append(d_mem)
    return inp_idx
# Run inference on the device and report per-layer timings.
def time_inference(engine, batch_size, inp):
    bindings = []
    mem = []
    inp_idx = 0
    for io in get_input_output_names(engine):
        inp_idx = create_memory(engine, io, bindings, mem,
                                batch_size, inp, inp_idx)
    context = engine.create_execution_context()
    g_prof = Profiler(500)
    context.set_profiler(g_prof)
    # Execute as many iterations as the profiler averages over.
    for i in range(g_prof.timing_iterations):
        context.execute(batch_size, bindings)
    g_prof.print_layer_times()
    context.destroy()
    return
def convert_to_datatype(v):
    # Map a bit depth (8/16/32) to the corresponding TensorRT data type.
    if v == 8:
        return trt.infer.DataType.INT8
    elif v == 16:
        return trt.infer.DataType.HALF
    elif v == 32:
        return trt.infer.DataType.FLOAT
    else:
        print("ERROR: Invalid model data type bit depth: " + str(v))
        return trt.infer.DataType.INT8
def run_trt_engine(engine_file, bs, inp):
    engine = trt.utils.load_engine(G_LOGGER, engine_file)
    time_inference(engine, bs, inp)
def run_onnx(onnx_file, data_type, bs, inp):
# Create onnx_config
apex = onnxparser.create_onnxconfig()
apex.set_model_file_name(onnx_file)
apex.set_model_dtype(convert_to_datatype(data_type))
# create parser
trt_parser = onnxparser.create_onnxparser(apex)
assert(trt_parser)
data_type = apex.get_model_dtype()
onnx_filename = apex.get_model_file_name()
trt_parser.parse(onnx_filename, data_type)
trt_parser.report_parsing_info()
trt_parser.convert_to_trtnetwork()
trt_network = trt_parser.get_trtnetwork()
assert(trt_network)
# create infer builder
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    # Build for the requested batch size; 1 GiB workspace is an assumed default.
    trt_builder.set_max_batch_size(bs)
    trt_builder.set_max_workspace_size(1 << 30)
    if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
        print("------------------- Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
        print("------------------- Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("------------------- Running FP32 -----------------------------")
    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    time_inference(trt_engine, bs, inp)
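# --- Hypothetical entry point (an assumption, not part of the original file) --
# The script defines run_trt_engine()/run_onnx() but no driver. A minimal
# argparse wrapper might look like the following; the flag names and the
# dummy input size are illustrative only.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Time a TensorRT engine or ONNX model")
    parser.add_argument("--engine", help="path to a serialized TensorRT engine")
    parser.add_argument("--onnx", help="path to an ONNX model")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--data_type", type=int, default=32, choices=[8, 16, 32])
    args = parser.parse_args()

    # Random host-side input, sized for a 3 x 512 x 1024 tensor; real usage
    # would load and preprocess images to match the engine's input dimensions.
    dummy_input = [np.random.uniform(0.0, 255.0, 3 * 512 * 1024).astype(np.float32)]
    if args.engine:
        run_trt_engine(args.engine, args.batch_size, dummy_input)
    elif args.onnx:
        run_onnx(args.onnx, args.data_type, args.batch_size, dummy_input)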
#!/bin/bash
################################ Testing ################################
# labels only
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none $@
################################ Testing ################################
# first precompute and cluster all features
python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
# use instance-wise features
python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat
################################ Testing ################################
# labels only
python test.py --name label2city_512p
################################ Testing ################################
# first precompute and cluster all features
python encode_features.py --name label2city_512p_feat;
# use instance-wise features
python test.py --name label2city_512p_feat --instance_feat