Commit 05f0839a authored by dengjb's avatar dengjb
Browse files

update

parents
# 模型唯一标识
modelCode=1819
# 模型名称
modelName=CycleGan_pytorch
# 模型描述
modelDescription=一种基于对抗生成网络的非配对图像转换模型.
# 处理类型
processType=推理
# 应用场景
appScenario=图像生成
# 框架类型
frameType=pytorch
# 加速卡类型
accelerateType=K100AI
\ No newline at end of file
require 'nngraph'
----------------------------------------------------------------------------
--- DCGAN-style weight initialization, applied via net:apply(weights_init):
-- convolution weights ~ N(0, 0.02) with zero bias; normalization layers
-- get weight ~ N(1, 0.02) and zero bias when those parameters exist.
local function weights_init(m)
  local kind = torch.type(m)
  if kind:find('Convolution') then
    m.weight:normal(0.0, 0.02)
    m.bias:fill(0)
    return
  end
  if kind:find('Normalization') then
    if m.weight then m.weight:normal(1.0, 0.02) end
    if m.bias then m.bias:fill(0) end
  end
end
-- Module constructor shared by all network builders below; selected once
-- via set_normalization() before any network is defined.
normalization = nil

--- Selects the normalization layer used by the generator/discriminator
-- builders in this file.
-- @param norm 'instance' (util.InstanceNormalization) or 'batch'
--        (nn.SpatialBatchNormalization)
function set_normalization(norm)
  if norm == 'instance' then
    require 'util.InstanceNormalization'
    print('use InstanceNormalization')
    normalization = nn.InstanceNormalization
  elseif norm == 'batch' then
    print('use SpatialBatchNormalization')
    normalization = nn.SpatialBatchNormalization
  else
    -- Previously an unknown value silently left `normalization` nil and the
    -- failure surfaced later as a confusing crash inside a network builder.
    error(('unsupported normalization type: %s'):format(tostring(norm)))
  end
end
--- Builds the requested generator architecture and applies DCGAN weight
-- initialization. `nz` and `arch` are accepted for interface compatibility
-- but are not forwarded to the builders.
function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
  local builders = {
    encoder_decoder = defineG_encoder_decoder,
    unet128 = defineG_unet128,
    unet256 = defineG_unet256,
    resnet_6blocks = defineG_resnet_6blocks,
    resnet_9blocks = defineG_resnet_9blocks,
  }
  local build = builders[which_model_netG]
  if not build then
    error("unsupported netG model")
  end
  local netG = build(input_nc, output_nc, ngf)
  netG:apply(weights_init)
  return netG
end
--- Builds the requested discriminator architecture and applies DCGAN
-- weight initialization. `n_layers_D` is only used by the "n_layers" model.
function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
  local factories = {
    basic = function() return defineD_basic(input_nc, ndf, use_sigmoid) end,
    imageGAN = function() return defineD_imageGAN(input_nc, ndf, use_sigmoid) end,
    n_layers = function() return defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid) end,
  }
  local factory = factories[which_model_netD] or error("unsupported netD model")
  local netD = factory()
  netD:apply(weights_init)
  return netD
end
--- Encoder-decoder generator for 256x256 inputs (no skip connections):
-- eight stride-2 4x4 convs down to a 1x1 bottleneck, then eight stride-2
-- full (transposed) convs back up, ending in tanh.
-- Uses the global `normalization` set by set_normalization(); dropout 0.5
-- is applied on the three innermost decoder stages.
function defineG_encoder_decoder(input_nc, output_nc, ngf)
-- input is (nc) x 256 x 256
local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
-- input is (ngf) x 128 x 128
local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 8 x 8
local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 4 x 4
local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 2 x 2
-- (normalization deliberately commented out at the 1x1 bottleneck below)
local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
-- input is (ngf * 8) x 1 x 1
local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 2 x 2
local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 4 x 4
local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 8 x 8
local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
-- input is (ngf) x128 x 128
local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1)
-- input is (nc) x 256 x 256
local o1 = d8 - nn.Tanh()
local netG = nn.gModule({e1},{o1})
return netG
end
--- U-Net generator for 128x128 inputs: seven stride-2 down-convs to a 1x1
-- bottleneck, then seven up-convs; each decoder stage is concatenated
-- (nn.JoinTable over the channel dim 2) with the mirrored encoder
-- activation, so the up-convs take doubled input channels.
-- Uses the global `normalization`; dropout 0.5 on the three innermost
-- decoder stages.
function defineG_unet128(input_nc, output_nc, ngf)
local netG = nil
-- input is (nc) x 128 x 128
local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
-- input is (ngf) x 64 x 64
local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 2) x 32 x 32
local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 4) x 16 x 16
local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 8 x 8
local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 4 x 4
local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 2 x 2
-- (normalization deliberately commented out at the 1x1 bottleneck below)
local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
-- input is (ngf * 8) x 1 x 1
local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 2 x 2
local d1 = {d1_,e6} - nn.JoinTable(2)
local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 4 x 4
local d2 = {d2_,e5} - nn.JoinTable(2)
local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 8 x 8
local d3 = {d3_,e4} - nn.JoinTable(2)
local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 8) x 16 x 16
local d4 = {d4_,e3} - nn.JoinTable(2)
local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 4) x 32 x 32
local d5 = {d5_,e2} - nn.JoinTable(2)
local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
-- input is (ngf * 2) x 64 x 64
local d6 = {d6_,e1} - nn.JoinTable(2)
local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
-- input is (nc) x 128 x 128
local o1 = d7 - nn.Tanh()
local netG = nn.gModule({e1},{o1})
return netG
end
--- U-Net generator for 256x256 inputs: eight stride-2 down-convs to a 1x1
-- bottleneck, then eight up-convs with channel-wise skip concatenation
-- (nn.JoinTable over dim 2) from the mirrored encoder stage.
-- Uses the global `normalization`; dropout 0.5 on the three innermost
-- decoder stages.
function defineG_unet256(input_nc, output_nc, ngf)
local netG = nil
-- input is (nc) x 256 x 256
local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
-- input is (ngf) x 128 x 128
local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 8 x 8
local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 4 x 4
local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 2 x 2
-- (normalization deliberately commented out at the 1x1 bottleneck below)
local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- - normalization(ngf * 8)
-- input is (ngf * 8) x 1 x 1
local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 2 x 2
local d1 = {d1_,e7} - nn.JoinTable(2)
local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 4 x 4
local d2 = {d2_,e6} - nn.JoinTable(2)
local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
-- input is (ngf * 8) x 8 x 8
local d3 = {d3_,e5} - nn.JoinTable(2)
local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
-- input is (ngf * 8) x 16 x 16
local d4 = {d4_,e4} - nn.JoinTable(2)
local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
-- input is (ngf * 4) x 32 x 32
local d5 = {d5_,e3} - nn.JoinTable(2)
local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
-- input is (ngf * 2) x 64 x 64
local d6 = {d6_,e2} - nn.JoinTable(2)
local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
-- input is (ngf) x128 x 128
local d7 = {d7_,e1} - nn.JoinTable(2)
local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
-- input is (nc) x 256 x 256
local o1 = d8 - nn.Tanh()
local netG = nn.gModule({e1},{o1})
return netG
end
--------------------------------------------------------------------------------
-- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
--------------------------------------------------------------------------------
--- conv -> norm -> ReLU -> conv -> norm block with a configurable border
-- treatment, preserving the spatial size.
-- @param dim number of input/output channels
-- @param padding_type 'reflect' | 'replicate' | 'zero'
local function build_conv_block(dim, padding_type)
  local conv_block = nn.Sequential()
  local p = 0
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  elseif padding_type == 'zero' then
    p = 1
  else
    -- Previously an unknown value silently produced a block that shrinks
    -- its input by 2 pixels per conv; fail fast instead.
    error(('unsupported padding type: %s'):format(tostring(padding_type)))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  conv_block:add(nn.ReLU(true))
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  return conv_block
end
--- Residual block: output = x + conv_block(x), realized as a ConcatTable
-- over (conv branch, identity) followed by element-wise addition.
local function build_res_block(dim, padding_type)
  local branches = nn.ConcatTable()
    :add(build_conv_block(dim, padding_type))
    :add(nn.Identity())
  return nn.Sequential():add(branches):add(nn.CAddTable())
end
--- ResNet generator with 6 residual blocks: c7s1 conv, two stride-2
-- downsampling convs, 6 residual blocks at ngf*4 channels, two stride-2
-- upsampling full convs, then a c7s1 conv with tanh.
function defineG_resnet_6blocks(input_nc, output_nc, ngf)
  local padding_type = 'reflect' -- was leaking into the global environment
  local ks = 3
  local f = 7
  local p = (f - 1) / 2
  local data = - nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
  -- 6 residual blocks at the bottleneck resolution
  local d1 = e3
  for _ = 1, 6 do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
  return nn.gModule({data}, {d4})
end
--- ResNet generator with 9 residual blocks; identical to the 6-block
-- variant apart from the number of residual blocks at the bottleneck.
function defineG_resnet_9blocks(input_nc, output_nc, ngf)
  local padding_type = 'reflect' -- was leaking into the global environment
  local ks = 3
  local f = 7
  local p = (f - 1) / 2
  local data = - nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
  -- 9 residual blocks at the bottleneck resolution
  local d1 = e3
  for _ = 1, 9 do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
  return nn.gModule({data}, {d4})
end
--- Full-image discriminator for 256x256 inputs: six stride-2 4x4 convs
-- (batch norm + LeakyReLU after all but the first), then a final stride-2
-- conv producing a single 1x1 score.
-- Note: hardcodes nn.SpatialBatchNormalization (does not use the global
-- `normalization`), matching the original behavior.
function defineD_imageGAN(input_nc, ndf, use_sigmoid)
  local netD = nn.Sequential()
  local function down(nIn, nOut, with_bn)
    netD:add(nn.SpatialConvolution(nIn, nOut, 4, 4, 2, 2, 1, 1))
    if with_bn then netD:add(nn.SpatialBatchNormalization(nOut)) end
    netD:add(nn.LeakyReLU(0.2, true))
  end
  down(input_nc, ndf, false)   -- (ndf) x 128 x 128
  down(ndf, ndf * 2, true)     -- (ndf*2) x 64 x 64
  down(ndf * 2, ndf * 4, true) -- (ndf*4) x 32 x 32
  down(ndf * 4, ndf * 8, true) -- (ndf*8) x 16 x 16
  down(ndf * 8, ndf * 8, true) -- (ndf*8) x 8 x 8
  down(ndf * 8, ndf * 8, true) -- (ndf*8) x 4 x 4
  -- final projection to a 1 x 1 x 1 score
  netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1))
  if use_sigmoid then
    netD:add(nn.Sigmoid())
  end
  return netD
end
--- Standard PatchGAN discriminator: 3 downsampling layers (70x70
-- receptive field with the default 4x4 kernels).
function defineD_basic(input_nc, ndf, use_sigmoid)
  local n_layers = 3 -- was leaking into the global environment
  return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
end
-- rf=1
--- 1x1-conv ("pixel") discriminator: every conv has kernel 1 and stride 1,
-- so each output score depends on a single input pixel and the spatial
-- size is preserved throughout (sizes below assume a 256x256 input).
function defineD_pixelGAN(input_nc, ndf, use_sigmoid)
local netD = nn.Sequential()
-- input is (nc) x 256 x 256
netD:add(nn.SpatialConvolution(input_nc, ndf, 1, 1, 1, 1, 0, 0))
netD:add(nn.LeakyReLU(0.2, true))
-- state size: (ndf) x 256 x 256
netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0))
netD:add(normalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*2) x 256 x 256
netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0))
-- state size: 1 x 256 x 256
if use_sigmoid then
netD:add(nn.Sigmoid())
-- sigmoid squashes each per-pixel score to [0,1]; shape is unchanged
end
return netD
end
--- N-layer PatchGAN discriminator.
-- n_layers = 0 degenerates to the pixel discriminator (receptive field 1);
-- otherwise the receptive field with kw=4 is: 16 (n=1), 34 (n=2), 70 (n=3),
-- 142 (n=4), 286 (n=5), 574 (n=6).
-- @param kw optional kernel width/height (default 4)
-- @param dropout_ratio optional dropout after each norm layer (default 0.0)
function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)
  if dropout_ratio == nil then
    dropout_ratio = 0.0
  end
  if kw == nil then
    kw = 4
  end
  local padw = math.ceil((kw - 1) / 2) -- was leaking into the global environment
  if n_layers == 0 then
    return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  end
  local netD = nn.Sequential()
  -- first downsampling conv has no normalization
  netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
  netD:add(nn.LeakyReLU(0.2, true))
  -- remaining downsampling convs, doubling channels up to a cap of ndf*8
  local nf_mult = 1
  local nf_mult_prev = 1
  for n = 1, n_layers - 1 do
    nf_mult_prev = nf_mult
    nf_mult = math.min(2 ^ n, 8)
    netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw, padw))
    netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
    netD:add(nn.LeakyReLU(0.2, true))
  end
  -- one stride-1 conv, then a 1-channel score map
  nf_mult_prev = nf_mult
  nf_mult = math.min(2 ^ n_layers, 8)
  netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
  netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw, padw))
  if use_sigmoid then
    netD:add(nn.Sigmoid())
  end
  return netD
end
--------------------------------------------------------------------------------
-- Base Class for Providing Models
--------------------------------------------------------------------------------
local class = require 'class'

--- Abstract base class for all GAN models in this project.
-- Concrete models (BiGANModel, ContentGANModel, ...) override these hooks.
BaseModel = class('BaseModel')

function BaseModel:__init(conf)
  conf = conf or {}
end

--- Returns the name of the model.
function BaseModel:model_name()
  return 'DoesNothingModel'
end

--- Defines models and networks; returns a table of networks.
function BaseModel:Initialize(opt)
  local models = {} -- was leaking into the global environment
  return models
end

--- Runs the forward pass of the network; returns a table of outputs.
function BaseModel:Forward(input, opt)
  local output = {} -- was leaking into the global environment
  return output
end

--- Runs the backprop gradient descent for a single batch of data.
function BaseModel:OptimizeParameters(opt)
end

--- Can be used to reset optimizer state after each epoch.
function BaseModel:RefreshParameters(opt)
end

--- Can be used to decay the learning rate after each epoch.
function BaseModel:UpdateLearningRate(opt)
end

--- Saves the current model to the file system under `prefix`.
function BaseModel:Save(prefix, opt)
end

--- Returns a string that describes the current errors.
function BaseModel:GetCurrentErrorDescription()
  return "No Error exists in BaseModel"
end

--- Returns current errors as a table.
function BaseModel:GetCurrentErrors(opt)
  return {}
end

--- Returns a table of {img=..., label=...} pairs describing current results.
function BaseModel:GetCurrentVisuals(opt, size)
  return {}
end

--- Returns the display plot configuration.
function BaseModel:DisplayPlot(opt)
  return {}
end
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')
--- Bidirectional GAN model: a generator G (A -> B), an encoder E (B -> A)
-- and a single joint discriminator D over channel-concatenated (A, B) pairs.
BiGANModel = class('BiGANModel', 'BaseModel')
function BiGANModel:__init(conf)
BaseModel.__init(self, conf)
conf = conf or {}
end
--- Name used for logging and checkpoints.
function BiGANModel:model_name()
  local name = 'BiGANModel'
  return name
end
--- Builds a fresh adam optimizer state.
-- NOTE(review): reads the global `opt` rather than a parameter, and the
-- `use_wgan` argument is currently unused -- confirm both are intended.
function BiGANModel:InitializeStates(use_wgan)
  local optimState = {learningRate = opt.lr, beta1 = opt.beta1,} -- was leaking into the global environment
  return optimState
end
--- Defines models and networks: generator G (A -> B), encoder E (B -> A)
-- and a joint discriminator D over channel-concatenated (A, B) pairs,
-- either freshly built or loaded from checkpoints when continuing training.
function BiGANModel:Initialize(opt)
  if opt.test == 0 then
    self.realABPool = ImagePool(opt.pool_size)
    self.fakeABPool = ImagePool(opt.pool_size)
  end
  -- joint tensors holding an A image and a B image stacked along channels
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  -- LSGAN objective (least-squares loss)
  self.criterionGAN = nn.MSECriterion()
  local netG, netE, netD = nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- which_epoch option exists in test mode
      netG = util.load_test_model('G', opt)
      netE = util.load_test_model('E', opt)
      netD = util.load_test_model('D', opt)
    else
      netG = util.load_model('G', opt)
      netE = util.load_model('E', opt)
      netD = util.load_model('D', opt)
    end
  else
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer
    print('netD...', netD)
    -- NOTE(review): defineG's 5th parameter is `nz`; opt.arch is passed in
    -- that slot here, but both nz and arch are unused by defineG -- confirm.
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG...', netG)
    netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netE...', netE)
  end
  self.netD = netD
  self.netG = netG
  self.netE = netE
  -- size real/fake label tensors to D's output (a PatchGAN score map)
  local netD_output_size = self.netD:forward(self.real_AB):size() -- was leaking into the global environment
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self.optimStateE = self:InitializeStates()
  -- channel slices selecting the A half and the B half of a joint AB tensor
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc + 1, opt.input_nc + opt.output_nc}, {}, {}}
  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('E = %d'):format(self.parametersE:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
-- Runs the forward pass of the network and
-- saves the result to member variables of the class
--- Populates real_AB = (real_A, G(real_A)) and fake_AB = (E(real_B), real_B).
function BiGANModel:Forward(input, opt)
if opt.which_direction == 'BtoA' then
-- swap domains so the rest of the code can assume the AtoB direction
local temp = input.real_A
input.real_A = input.real_B
input.real_B = temp
end
self.real_AB[self.A_idx]:copy(input.real_A)
self.fake_AB[self.B_idx]:copy(input.real_B)
-- NOTE(review): these range-indexed slices appear to be views sharing
-- storage with real_AB/fake_AB, so writes below update the joint tensors
-- -- confirm against torch indexing semantics.
self.real_A = self.real_AB[self.A_idx]
self.real_B = self.fake_AB[self.B_idx]
self.fake_B = self.netG:forward(self.real_A):clone()
self.fake_A = self.netE:forward(self.real_B):clone()
self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label
self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label
-- if opt.test == 0 then
-- self.real_AB = self.realABPool:Query(self.real_AB) -- batch history
-- self.fake_AB = self.fakeABPool:Query(self.fake_AB) -- batch history
-- end
end
-- create closure to evaluate f(X) and df/dX of discriminator
--- Shared D update: pushes D(real_AB) toward the real label and D(fake_AB)
-- toward the fake label under the MSE (LSGAN) criterion. Gradients are
-- accumulated into gradParams; returns (errD, gradParams).
function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
util.BiasZero(netD)
gradParams:zero()
-- Real log(D_A(B))
local output = netD:forward(real_AB):clone()
local errD_real = self.criterionGAN:forward(output, self.real_label)
local df_do = self.criterionGAN:backward(output, self.real_label)
netD:backward(real_AB, df_do)
-- Fake + log(1 - D_A(G(A)))
output = netD:forward(fake_AB):clone()
local errD_fake = self.criterionGAN:forward(output, self.fake_label)
local df_do2 = self.criterionGAN:backward(output, self.fake_label)
netD:backward(fake_AB, df_do2)
-- Compute loss
-- average of the real and fake branch losses
local errD = (errD_real + errD_fake) / 2.0
return errD, gradParams
end
--- Discriminator closure for optim: queries the image pools (which may
-- return historical generated pairs) and computes D's loss and gradients.
function BiGANModel:fDx(x, opt)
  -- use image pools that store old fake images
  local real_AB = self.realABPool:Query(self.real_AB) -- was leaking into the global environment
  local fake_AB = self.fakeABPool:Query(self.fake_AB) -- was leaking into the global environment
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, real_AB, fake_AB, opt)
  return self.errD, gradParams
end
--- G update: evaluates D on real_AB = (real_A, fake_B) and backpropagates
-- the GAN gradient through D (via updateGradInput, leaving D's own
-- gradients untouched) into G through the B-channel slice.
-- NOTE(review): the generator's target here is self.fake_label, which is
-- the opposite of the usual non-saturating convention -- confirm this sign
-- convention is intended for this bidirectional setup.
function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
util.BiasZero(netG)
util.BiasZero(netD)
gradParametersG:zero()
-- First. G(A) should fake the discriminator
local output = netD:forward(self.real_AB):clone()
local errG = self.criterionGAN:forward(output, self.fake_label)
local dgan_loss_dd = self.criterionGAN:backward(output, self.fake_label)
local dgan_loss_do = netD:updateGradInput(self.real_AB, dgan_loss_dd)
netG:backward(self.real_A, dgan_loss_do[self.B_idx]) -- real_AB: real_A, fake_B -> real_label
return gradParametersG, errG
end
--- Generator closure for optim: returns (loss, gradients) for G.
function BiGANModel:fGx(x, opt)
  local grads, err = self:fGx_basic(x, self.netG, self.netD, self.gradParametersG, opt)
  self.gradParametersG, self.errG = grads, err
  return self.errG, self.gradParametersG
end
--- E update: evaluates D on fake_AB = (fake_A, real_B) and backpropagates
-- the GAN gradient through D (via updateGradInput, leaving D's own
-- gradients untouched) into E through the A-channel slice. E's target is
-- self.real_label, the mirror image of fGx_basic's convention.
function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
util.BiasZero(netE)
util.BiasZero(netD)
gradParametersE:zero()
-- First. G(A) should fake the discriminator
local output = netD:forward(self.fake_AB):clone()
local errE= self.criterionGAN:forward(output, self.real_label)
local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
local dgan_loss_do = netD:updateGradInput(self.fake_AB, dgan_loss_dd)
netE:backward(self.real_B, dgan_loss_do[self.A_idx])-- fake_AB: fake_A, real_B -> fake_label
return gradParametersE, errE
end
--- Encoder closure for optim: returns (loss, gradients) for E.
function BiGANModel:fEx(x, opt)
  local grads, err = self:fEx_basic(x, self.netE, self.netD, self.gradParametersE, opt)
  self.gradParametersE, self.errE = grads, err
  return self.errE, self.gradParametersE
end
--- Runs one adam step for each of D, G and E, in that order.
function BiGANModel:OptimizeParameters(opt)
  local steps = {
    {function(x) return self:fDx(x, opt) end, self.parametersD, self.optimStateD},
    {function(x) return self:fGx(x, opt) end, self.parametersG, self.optimStateG},
    {function(x) return self:fEx(x, opt) end, self.parametersE, self.optimStateE},
  }
  for _, step in ipairs(steps) do
    optim.adam(step[1], step[2], step[3])
  end
end
--- Re-flattens the parameters of all three networks. The old flattened
-- views are dropped first so two copies are never held at once.
function BiGANModel:RefreshParameters()
  for _, k in ipairs({'D', 'G', 'E'}) do
    self['parameters' .. k], self['gradParameters' .. k] = nil, nil
  end
  for _, k in ipairs({'D', 'G', 'E'}) do
    self['parameters' .. k], self['gradParameters' .. k] = self['net' .. k]:getParameters()
  end
end
--- Writes G, E and D checkpoints under the given path prefix.
function BiGANModel:Save(prefix, opt)
  local nets = {G = self.netG, E = self.netE, D = self.netD}
  for _, name in ipairs({'G', 'E', 'D'}) do
    util.save_model(nets[name], prefix .. '_net_' .. name .. '.t7', 1)
  end
end
--- One-line summary of the latest D/G/E losses (-1 before the first
-- optimization step).
function BiGANModel:GetCurrentErrorDescription()
  local description = ('D: %.4f G: %.4f E: %.4f'):format( -- was leaking into the global environment
    self.errD and self.errD or -1,
    self.errG and self.errG or -1,
    self.errE and self.errE or -1)
  return description
end
--- Latest losses keyed by name (nil before the first optimization step).
function BiGANModel:GetCurrentErrors()
  return {errD = self.errD, errG = self.errG, errE = self.errE}
end
--- Comma-separated list of error names to plot.
function BiGANModel:DisplayPlot(opt)
  local names = {'errD', 'errG', 'errE'}
  return table.concat(names, ',')
end
--- Linearly decays the learning rate by opt.lr / opt.niter_decay per call
-- and applies the new value to all three optimizer states.
function BiGANModel:UpdateLearningRate(opt)
  local decay = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  local new_lr = old_lr - decay
  for _, state in ipairs({self.optimStateD, self.optimStateG, self.optimStateE}) do
    state['learningRate'] = new_lr
  end
  print(('update learning rate: %f -> %f'):format(old_lr, new_lr))
end
--- Replicates a single-channel image batch to 3 channels for display;
-- input that already has more than one channel is returned unchanged.
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
--- Image/label pairs for the display window, converted to 3 channels.
function BiGANModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize
  local visuals = {}
  local items = {
    {self.real_A, 'real_A'},
    {self.fake_B, 'fake_B'},
    {self.real_B, 'real_B'},
    {self.fake_A, 'fake_A'},
  }
  for _, item in ipairs(items) do
    visuals[#visuals + 1] = {img = MakeIm3(item[1]), label = item[2]}
  end
  return visuals
end
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')
--- GAN with a content loss: a single generator G (A -> B) trained against
-- a discriminator D plus a feature-space content loss tying fake_B to real_A.
ContentGANModel = class('ContentGANModel', 'BaseModel')
function ContentGANModel:__init(conf)
BaseModel.__init(self, conf)
conf = conf or {}
end
--- Name used for logging and checkpoints.
function ContentGANModel:model_name()
  local name = 'ContentGANModel'
  return name
end
--- Builds a fresh adam optimizer state (reads the global `opt`).
function ContentGANModel:InitializeStates()
  local state = {}
  state.learningRate = opt.lr
  state.beta1 = opt.beta1
  return state
end
--- Defines models and networks: generator G and discriminator D, together
-- with the LSGAN (MSE) criterion and an L1 content criterion over features
-- extracted by `contentFunc`.
function ContentGANModel:Initialize(opt)
  if opt.test == 0 then
    self.fakePool = ImagePool(opt.pool_size)
  end
  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
  self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  self.real_B = self.fake_B:clone()
  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionContent = nn.AbsCriterion()
  self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)
  self.netG, self.netD = nil, nil
  if opt.continue_train == 1 then
    if opt.which_epoch then -- which_epoch option exists in test mode
      self.netG = util.load_test_model('G_A', opt)
      self.netD = util.load_test_model('D_A', opt)
    else
      self.netG = util.load_model('G_A', opt)
      self.netD = util.load_model('D_A', opt)
    end
  else
    self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    print('netG...', self.netG)
    self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
    print('netD...', self.netD)
  end
  -- size real/fake label tensors to D's output
  -- NOTE(review): D is built for output_nc channels but probed here with
  -- real_A (input_nc channels); this only works when input_nc == output_nc
  -- -- confirm.
  local netD_output_size = self.netD:forward(self.real_A):size() -- was leaking into the global environment
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
--- Runs the forward pass: copies the input batch into member tensors
-- (swapping A/B when direction is BtoA) and computes fake_B = G(real_A).
function ContentGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    input.real_A, input.real_B = input.real_B, input.real_A
  end
  self.real_A:copy(input.real_A)
  self.real_B:copy(input.real_B)
  self.fake_B = self.netG:forward(self.real_A):clone()
  local output = {} -- was leaking into the global environment
  return output
end
--- Shared D update: pushes D(real_target) toward the real label and
-- D(fake_target) toward the fake label under the MSE (LSGAN) criterion.
-- Gradients are accumulated into gradParams; returns (errD, gradParams).
function ContentGANModel:fDx_basic(x, gradParams, netD, netG,
  real_target, fake_target, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()
  -- Real: D(real) -> real_label
  local output = netD:forward(real_target)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label) -- was leaking into the global environment
  netD:backward(real_target, df_do)
  -- Fake: D(fake) -> fake_label
  output = netD:forward(fake_target)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  df_do = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_target, df_do)
  -- average of the two branch losses
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end
--- Discriminator closure for optim: queries the fake-image pool (which may
-- return a historical generated image) and computes D's loss and gradients.
function ContentGANModel:fDx(x, opt)
  local fake_B = self.fakePool:Query(self.fake_B) -- was leaking into the global environment
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
    self.real_B, fake_B, opt)
  return self.errD, gradParams
end
--- G update: combines the GAN gradient (through D, via updateGradInput so
-- D's own gradients are untouched) with the content-loss gradient from
-- content.lossUpdate, then backpropagates their sum through G.
-- Returns (gradParametersG_source, errGAN, errContent).
function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
gradParametersG_source, opt)
util.BiasZero(netD_source)
util.BiasZero(netG_source)
gradParametersG_source:zero()
-- GAN loss
-- local df_d_GAN = torch.zeros(fake_target:size())
-- local errGAN = 0
-- local errRec = 0
--- Domain GAN loss: D_A(G_A(A))
-- NOTE(review): reuses netD_source.output from the preceding fDx call
-- instead of re-running forward; this assumes fDx always runs first with
-- fake_target as D's last input -- confirm the call order in
-- OptimizeParameters.
local output = netD_source.output -- [hack] forward was already executed in fDx, so save computation netD_source:forward(fake_B) ---
local errGAN = self.criterionGAN:forward(output, self.real_label)
local df_do = self.criterionGAN:backward(output, self.real_label)
local df_d_GAN = netD_source:updateGradInput(fake_target, df_do) ---:narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
-- content loss
-- function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
local errContent, df_d_content = content.lossUpdate(self.criterionContent, real_source, fake_target, self.contentFunc, opt.content_loss, opt.lambda_A)
-- re-run G's forward so its internal state matches real_source before backward
netG_source:forward(real_source)
netG_source:backward(real_source, df_d_GAN + df_d_content)
return gradParametersG_source, errGAN, errContent
end
-- Generator closure for optim: wraps fGx_basic and caches the losses on self.
function ContentGANModel:fGx(x, opt)
  local grads, errG, errCont = self:fGx_basic(x, self.netG, self.netD,
      self.real_A, self.real_B, self.fake_B,
      self.gradparametersG, opt)
  self.gradparametersG, self.errG, self.errCont = grads, errG, errCont
  return self.errG, self.gradparametersG
end
-- Runs one Adam step for the discriminator, then one for the generator.
function ContentGANModel:OptimizeParameters(opt)
  optim.adam(function(x) return self:fDx(x, opt) end, self.parametersD, self.optimStateD)
  optim.adam(function(x) return self:fGx(x, opt) end, self.parametersG, self.optimStateG)
end
-- Rebuilds the flattened parameter/gradient views after the networks change.
function ContentGANModel:RefreshParameters()
  -- drop the old views first so memory use does not spike while re-flattening
  self.parametersD, self.gradparametersD = nil, nil
  self.parametersG, self.gradparametersG = nil, nil
  -- re-flatten for the optimizer
  self.parametersG, self.gradparametersG = self.netG:getParameters()
  self.parametersD, self.gradparametersD = self.netD:getParameters()
end
-- Writes the generator and discriminator checkpoints under the given prefix.
function ContentGANModel:Save(prefix, opt)
  local checkpoints = {
    {self.netG, '_net_G_A.t7'},
    {self.netD, '_net_D_A.t7'},
  }
  for _, entry in ipairs(checkpoints) do
    util.save_model(entry[1], prefix .. entry[2], 1.0)
  end
end
-- Returns a one-line summary of the current losses (-1 = not yet computed).
-- Fix: `description` was an accidental global; declared local.
function ContentGANModel:GetCurrentErrorDescription()
  local description = ('G: %.4f D: %.4f Content: %.4f'):format(self.errG and self.errG or -1,
                         self.errD and self.errD or -1,
                         self.errCont and self.errCont or -1)
  return description
end
-- Returns the current losses as a table (-1 = not yet computed).
function ContentGANModel:GetCurrentErrors()
  return {
    errG = self.errG and self.errG or -1,
    errD = self.errD and self.errD or -1,
    errCont = self.errCont and self.errCont or -1,
  }
end
-- returns a string that describes the display plot configuration
function ContentGANModel:DisplayPlot(opt)
  local plot_keys = 'errG,errD,errCont'
  return plot_keys
end
-- Returns image/label pairs for display.
function ContentGANModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize -- kept for interface parity; not used below
  local visuals = {
    {img = self.real_A, label = 'real_A'},
    {img = self.fake_B, label = 'fake_B'},
    {img = self.real_B, label = 'real_B'},
  }
  return visuals
end
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
CycleGANModel = class('CycleGANModel', 'BaseModel')
-- Constructor. Fix: `conf` was defaulted *after* being forwarded to the base
-- class, so the default was dead code; default it before the call instead.
function CycleGANModel:__init(conf)
  conf = conf or {}
  BaseModel.__init(self, conf)
end
-- Returns the model's display name.
function CycleGANModel:model_name()
  local name = 'CycleGANModel'
  return name
end
-- Builds a fresh Adam optimizer state.
-- NOTE(review): reads the global `opt` (the file's convention) and the
-- `use_wgan` argument is currently unused — confirm before removing it.
-- Fix: `optimState` was an accidental global; declared local.
function CycleGANModel:InitializeStates(use_wgan)
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end
-- Defines models and networks
-- Builds image pools and training tensors (train mode only), picks the GAN
-- criterion (BCE for vanilla GAN, MSE for LSGAN), and either loads networks
-- from checkpoints (continue_train) or defines them fresh. Label tensors are
-- sized by running one discriminator forward pass.
function CycleGANModel:Initialize(opt)
if opt.test == 0 then
self.fakeAPool = ImagePool(opt.pool_size)
self.fakeBPool = ImagePool(opt.pool_size)
end
-- define tensors
if opt.test == 0 then -- allocate tensors for training
self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
self.real_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
self.fake_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
self.rec_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
self.rec_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
end
-- load/define models
-- LSGAN uses a least-squares (MSE) objective; vanilla GAN uses BCE + sigmoid
local use_lsgan = ((opt.use_lsgan ~= nil) and (opt.use_lsgan == 1))
if not use_lsgan then
self.criterionGAN = nn.BCECriterion()
else
self.criterionGAN = nn.MSECriterion()
end
-- L1 criterion shared by cycle-consistency and identity losses
self.criterionRec = nn.AbsCriterion()
local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil
if opt.continue_train == 1 then
if opt.test == 1 then -- test mode
-- only the generators are needed at test time
netG_A = util.load_test_model('G_A', opt)
netG_B = util.load_test_model('G_B', opt)
--setup optnet to save a little bit of memory
if opt.use_optnet == 1 then
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
local optnet = require 'optnet'
optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
end
else
netG_A = util.load_model('G_A', opt)
netG_B = util.load_model('G_B', opt)
netD_A = util.load_model('D_A', opt)
netD_B = util.load_model('D_B', opt)
end
else
-- with LSGAN the criterion is MSE, so D ends without a sigmoid
local use_sigmoid = (not use_lsgan)
-- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
-- os.exit()
netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
print('netG_A...', netG_A)
netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer
print('netD_A...', netD_A)
netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
print('netG_B...', netG_B)
netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer
print('netD_B', netD_B)
end
self.netD_A = netD_A
self.netG_A = netG_A
self.netG_B = netG_B
self.netD_B = netD_B
-- define real/fake labels
if opt.test == 0 then
-- size the label tensors by probing D's output shape with a real batch
local D_A_size = self.netD_A:forward(self.real_B):size() -- hack: assume D_size_A = D_size_B
self.fake_label_A = torch.Tensor(D_A_size):fill(0.0)
self.real_label_A = torch.Tensor(D_A_size):fill(1.0) -- no soft smoothing
local D_B_size = self.netD_B:forward(self.real_A):size() -- hack: assume D_size_A = D_size_B
self.fake_label_B = torch.Tensor(D_B_size):fill(0.0)
self.real_label_B = torch.Tensor(D_B_size):fill(1.0) -- no soft smoothing
self.optimStateD_A = self:InitializeStates()
self.optimStateG_A = self:InitializeStates()
self.optimStateD_B = self:InitializeStates()
self.optimStateG_B = self:InitializeStates()
self:RefreshParameters()
print('---------- # Learnable Parameters --------------')
print(('G_A = %d'):format(self.parametersG_A:size(1)))
print(('D_A = %d'):format(self.parametersD_A:size(1)))
print(('G_B = %d'):format(self.parametersG_B:size(1)))
print(('D_B = %d'):format(self.parametersD_B:size(1)))
print('------------------------------------------------')
end
end
-- Runs the forward pass of the network and
-- saves the result to member variables of the class
-- In train mode only the real batches are copied into the preallocated
-- buffers (generator passes happen inside the optimization closures); in
-- test mode the full cycle (fakes, then reconstructions) is computed here.
function CycleGANModel:Forward(input, opt)
-- for BtoA the two domains are swapped in place in the input table
if opt.which_direction == 'BtoA' then
local temp = input.real_A:clone()
input.real_A = input.real_B:clone()
input.real_B = temp
end
if opt.test == 0 then
self.real_A:copy(input.real_A)
self.real_B:copy(input.real_B)
end
if opt.test == 1 then -- forward for test
if opt.gpu > 0 then
self.real_A = input.real_A:cuda()
self.real_B = input.real_B:cuda()
else
self.real_A = input.real_A:clone()
self.real_B = input.real_B:clone()
end
-- fakes first, then reconstructions from those fakes; clones are needed
-- because each network reuses its output buffer across forward calls
self.fake_B = self.netG_A:forward(self.real_A):clone()
self.fake_A = self.netG_B:forward(self.real_B):clone()
self.rec_A = self.netG_B:forward(self.fake_B):clone()
self.rec_B = self.netG_A:forward(self.fake_A):clone()
end
end
-- create closure to evaluate f(X) and df/dX of discriminator
-- Accumulates gradients for D on one real and one fake batch and returns
-- the averaged discriminator loss.
function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()
  -- real branch: D should output real_label on real images
  local pred_real = netD:forward(real)
  local loss_real = self.criterionGAN:forward(pred_real, real_label)
  netD:backward(real, self.criterionGAN:backward(pred_real, real_label))
  -- fake branch: D should output fake_label on generated images
  local pred_fake = netD:forward(fake)
  local loss_fake = self.criterionGAN:forward(pred_fake, fake_label)
  netD:backward(fake, self.criterionGAN:backward(pred_fake, fake_label))
  -- average the two halves, as in the original GAN formulation
  local loss_total = (loss_real + loss_fake) / 2.0
  return loss_total, gradParams
end
-- Closure for D_A: queries the image pool (mix of current and historical
-- fakes) and delegates to fDx_basic.
-- Fix: `fake_B` and `gradParams` were accidental globals; declared local.
function CycleGANModel:fDAx(x, opt)
  local fake_B = self.fakeBPool:Query(self.fake_B)
  local gradParams
  self.errD_A, gradParams = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A,
                                           self.real_B, fake_B, self.real_label_A, self.fake_label_A, opt)
  return self.errD_A, gradParams
end
-- Closure for D_B: queries the image pool (mix of current and historical
-- fakes) and delegates to fDx_basic.
-- Fix: `fake_A` and `gradParams` were accidental globals; declared local.
function CycleGANModel:fDBx(x, opt)
  local fake_A = self.fakeAPool:Query(self.fake_A)
  local gradParams
  self.errD_B, gradParams = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B,
                                           self.real_A, fake_A, self.real_label_B, self.fake_label_B, opt)
  return self.errD_B, gradParams
end
-- Generator update for one direction of the cycle. `netG` maps real -> fake,
-- `netE` is the inverse generator used for the cycle losses. Computes:
--   optional identity loss (netG applied to its own target domain),
--   GAN loss via netD, forward cycle loss (real -> fake -> rec),
--   and backward cycle loss through netE.
-- Returns grads, GAN loss, cycle loss, identity loss, and the fake/rec/identity
-- images for visualization. Statement order matters: each network's internal
-- state from forward is consumed by the following backward calls.
function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label, lambda1, lambda2, opt)
util.BiasZero(netD)
util.BiasZero(netG)
util.BiasZero(netE) -- inverse mapping
gradParams:zero()
-- G should be identity if real2 is fed.
local errI = nil
local identity = nil
if opt.lambda_identity > 0 then
identity = netG:forward(real2):clone()
errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity
local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity)
netG:backward(real2, didentity_loss_do)
end
--- GAN loss: D_A(G_A(A))
local fake = netG:forward(real):clone()
local output = netD:forward(fake)
local errG = self.criterionGAN:forward(output, real_label)
local df_do1 = self.criterionGAN:backward(output, real_label)
-- updateGradInput backprops through D without touching D's parameter grads
local df_d_GAN = netD:updateGradInput(fake, df_do1) --
-- forward cycle loss
local rec = netE:forward(fake):clone()
local errRec = self.criterionRec:forward(rec, real) * lambda1
local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1)
local df_do_rec = netE:updateGradInput(fake, df_do2)
-- single backward through G combines the GAN and cycle gradients
netG:backward(real, df_d_GAN + df_do_rec)
-- backward cycle loss
local fake2 = netE:forward(real2)--:clone()
local rec2 = netG:forward(fake2)--:clone()
local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2
local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2)
netG:backward(fake2, df_do_coadapt)
return gradParams, errG, errRec, errI, fake, rec, identity
end
-- Closure for G_A (A -> B direction); caches losses and images on self.
function CycleGANModel:fGAx(x, opt)
  local grads, errG, errRec, errI, fake, rec, idt =
      self:fGx_basic(x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B,
                     self.real_A, self.real_B, self.real_label_A, opt.lambda_A, opt.lambda_B, opt)
  self.gradparametersG_A = grads
  self.errG_A, self.errRec_A, self.errI_A = errG, errRec, errI
  self.fake_B, self.rec_A, self.identity_B = fake, rec, idt
  return self.errG_A, self.gradparametersG_A
end
-- Closure for G_B (B -> A direction); caches losses and images on self.
function CycleGANModel:fGBx(x, opt)
  local grads, errG, errRec, errI, fake, rec, idt =
      self:fGx_basic(x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A,
                     self.real_B, self.real_A, self.real_label_B, opt.lambda_B, opt.lambda_A, opt)
  self.gradparametersG_B = grads
  self.errG_B, self.errRec_B, self.errI_B = errG, errRec, errI
  self.fake_A, self.rec_B, self.identity_A = fake, rec, idt
  return self.errG_B, self.gradparametersG_B
end
-- One optimization step for all four networks, in the original order:
-- G_A, D_A, G_B, D_B.
function CycleGANModel:OptimizeParameters(opt)
  optim.adam(function(x) return self:fGAx(x, opt) end, self.parametersG_A, self.optimStateG_A)
  optim.adam(function(x) return self:fDAx(x, opt) end, self.parametersD_A, self.optimStateD_A)
  optim.adam(function(x) return self:fGBx(x, opt) end, self.parametersG_B, self.optimStateG_B)
  optim.adam(function(x) return self:fDBx(x, opt) end, self.parametersD_B, self.optimStateD_B)
end
-- Rebuilds the flattened parameter/gradient views for all four networks.
function CycleGANModel:RefreshParameters()
  -- release the old views first so memory use does not spike while re-flattening
  self.parametersD_A, self.gradparametersD_A = nil, nil
  self.parametersG_A, self.gradparametersG_A = nil, nil
  self.parametersG_B, self.gradparametersG_B = nil, nil
  self.parametersD_B, self.gradparametersD_B = nil, nil
  -- re-flatten for the optimizer
  self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
  self.parametersD_A, self.gradparametersD_A = self.netD_A:getParameters()
  self.parametersG_B, self.gradparametersG_B = self.netG_B:getParameters()
  self.parametersD_B, self.gradparametersD_B = self.netD_B:getParameters()
end
-- Writes all four network checkpoints under the given filename prefix.
function CycleGANModel:Save(prefix, opt)
  local checkpoints = {
    {self.netG_A, '_net_G_A.t7'},
    {self.netD_A, '_net_D_A.t7'},
    {self.netG_B, '_net_G_B.t7'},
    {self.netD_B, '_net_D_B.t7'},
  }
  for _, entry in ipairs(checkpoints) do
    util.save_model(entry[1], prefix .. entry[2], 1)
  end
end
-- Returns a one-line summary of all current losses (-1 = not yet computed).
-- Fix: `description` was an accidental global; declared local.
function CycleGANModel:GetCurrentErrorDescription()
  local description = ('[A] G: %.4f D: %.4f Rec: %.4f I: %.4f || [B] G: %.4f D: %.4f Rec: %.4f I:%.4f'):format(
                         self.errG_A and self.errG_A or -1,
                         self.errD_A and self.errD_A or -1,
                         self.errRec_A and self.errRec_A or -1,
                         self.errI_A and self.errI_A or -1,
                         self.errG_B and self.errG_B or -1,
                         self.errD_B and self.errD_B or -1,
                         self.errRec_B and self.errRec_B or -1,
                         self.errI_B and self.errI_B or -1)
  return description
end
-- Returns the current losses as a table (entries may be nil before training).
function CycleGANModel:GetCurrentErrors()
  return {
    errG_A = self.errG_A, errD_A = self.errD_A, errRec_A = self.errRec_A, errI_A = self.errI_A,
    errG_B = self.errG_B, errD_B = self.errD_B, errRec_B = self.errRec_B, errI_B = self.errI_B,
  }
end
-- returns a string that describes the display plot configuration
function CycleGANModel:DisplayPlot(opt)
  local keys = 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B'
  if opt.lambda_identity > 0 then
    -- identity losses are only tracked when the identity term is enabled
    keys = 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B'
  end
  return keys
end
-- Linearly decays the learning rate of all four optimizers by one step
-- (opt.lr / opt.niter_decay), reading the current rate from optimStateD_A.
function CycleGANModel:UpdateLearningRate(opt)
  local step = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD_A['learningRate']
  local new_lr = old_lr - step
  local states = {self.optimStateD_A, self.optimStateD_B, self.optimStateG_A, self.optimStateG_B}
  for _, state in ipairs(states) do
    state['learningRate'] = new_lr
  end
  print(('update learning rate: %f -> %f'):format(old_lr, new_lr))
end
-- Tiles single-channel (grayscale) images to 3 channels so the display code
-- can treat every image as RGB; multi-channel images pass through unchanged.
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
-- Returns image/label pairs for display; identity images are included only
-- during training when the identity loss is enabled.
function CycleGANModel:GetCurrentVisuals(opt, size)
  local show_identity = (opt.test == 0 and opt.lambda_identity > 0)
  local visuals = {}
  local function add(img, label)
    table.insert(visuals, {img = MakeIm3(img), label = label})
  end
  add(self.real_A, 'real_A')
  add(self.fake_B, 'fake_B')
  add(self.rec_A, 'rec_A')
  if show_identity then add(self.identity_A, 'identity_A') end
  add(self.real_B, 'real_B')
  add(self.fake_A, 'fake_A')
  add(self.rec_B, 'rec_B')
  if show_identity then add(self.identity_B, 'identity_B') end
  return visuals
end
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
-- Constructor. Fix: `conf` was defaulted *after* being forwarded to the base
-- class, so the default was dead code; default it before the call instead.
function OneDirectionTestModel:__init(conf)
  conf = conf or {}
  BaseModel.__init(self, conf)
end
-- Returns the model's display name.
function OneDirectionTestModel:model_name()
  local name = 'OneDirectionTestModel'
  return name
end
-- Defines models and networks
-- Test-only model: loads a single pretrained generator, optionally applies
-- optnet memory optimization, and prints the parameter count.
function OneDirectionTestModel:Initialize(opt)
-- define tensors
self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
-- load/define models
self.netG_A = util.load_test_model('G', opt)
-- setup optnet to save a bit of memory
if opt.use_optnet == 1 then
local optnet = require 'optnet'
-- a tiny 2x2 sample is enough for optnet to trace the graph
local sample_input = torch.randn(1, opt.input_nc, 2, 2)
optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
end
self:RefreshParameters()
print('---------- # Learnable Parameters --------------')
print(('G_A = %d'):format(self.parametersG_A:size(1)))
print('------------------------------------------------')
end
-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function OneDirectionTestModel:Forward(input, opt)
  -- for BtoA the B image becomes the source; the input table is mutated on purpose
  if opt.which_direction == 'BtoA' then
    input.real_A = input.real_B:clone()
  end
  local src = input.real_A:clone()
  if opt.gpu > 0 then
    src = src:cuda()
  end
  self.real_A = src
  self.fake_B = self.netG_A:forward(self.real_A):clone()
end
-- Rebuilds the flattened parameter/gradient views of the generator.
function OneDirectionTestModel:RefreshParameters()
  -- release the old views first, then re-flatten
  self.parametersG_A, self.gradparametersG_A = nil, nil
  self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
end
-- Tiles single-channel images to 3 channels for display; RGB images pass through.
local function MakeIm3(im)
  local channels = im:size(2)
  if channels == 1 then
    return torch.repeatTensor(im, 1, 3, 1, 1)
  end
  return im
end
-- Returns image/label pairs for display.
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize -- kept for interface parity; not used below
  local visuals = {
    {img = MakeIm3(self.real_A), label = 'real_A'},
    {img = MakeIm3(self.fake_B), label = 'fake_B'},
  }
  return visuals
end
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
Pix2PixModel = class('Pix2PixModel', 'BaseModel')
-- Constructor. Fix: unlike the other models in this file, this one never
-- called BaseModel.__init; call it for consistency with CycleGANModel et al.
function Pix2PixModel:__init(conf)
  conf = conf or {}
  BaseModel.__init(self, conf)
end
-- Returns the name of the model
function Pix2PixModel:model_name()
  local name = 'Pix2PixModel'
  return name
end
-- Builds a fresh Adam optimizer state.
-- NOTE(review): reads the global `opt`, matching the rest of the file.
function Pix2PixModel:InitializeStates()
  local state = {learningRate=opt.lr, beta1=opt.beta1,}
  return state
end
-- Defines models and networks
-- pix2pix is conditional: D sees the A/B pair concatenated along channels,
-- so its input has input_nc + output_nc channels. Uses MSE (LSGAN) + L1.
-- Fix: `netD_output_size` was an accidental global; declared local.
function Pix2PixModel:Initialize(opt) -- use lsgan
  -- define tensors for the concatenated conditional pairs
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  if opt.test == 0 then
    self.fakeABPool = ImagePool(opt.pool_size)
  end
  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionL1 = nn.AbsCriterion()
  local netG, netD = nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- only load model G for test
      netG = util.load_test_model('G', opt)
    else
      netG = util.load_model('G', opt)
      netD = util.load_model('D', opt)
    end
  else
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- with sigmoid
  end
  self.netD = netD
  self.netG = netG
  -- define real/fake labels sized by probing D's output shape
  if opt.test == 0 then
    local netD_output_size = self.netD:forward(self.real_AB):size()
    self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
    self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
    self.optimStateD = self:InitializeStates()
    self.optimStateG = self:InitializeStates()
    self:RefreshParameters()
    print('---------- # Learnable Parameters --------------')
    print(('G = %d'):format(self.parametersG:size(1)))
    print(('D = %d'):format(self.parametersD:size(1)))
    print('------------------------------------------------')
  end
  -- channel-slice index tables: A occupies channels 1..input_nc, B the rest
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
end
-- Runs the forward pass of the network
-- In train mode, copies the real pair into real_AB, generates fake_B from
-- real_A, and assembles fake_AB = [real_A, fake_B] for the conditional D.
-- In test mode only the generator pass is run.
function Pix2PixModel:Forward(input, opt)
-- for BtoA the two sides are swapped (note: swaps references, no clone)
if opt.which_direction == 'BtoA' then
local temp = input.real_A
input.real_A = input.real_B
input.real_B = temp
end
if opt.test == 0 then
self.real_AB[self.A_idx]:copy(input.real_A)
self.real_AB[self.B_idx]:copy(input.real_B)
-- real_A/real_B are channel views into real_AB, not copies
self.real_A = self.real_AB[self.A_idx]
self.real_B = self.real_AB[self.B_idx]
self.fake_AB[self.A_idx]:copy(self.real_A)
self.fake_B = self.netG:forward(self.real_A):clone()
self.fake_AB[self.B_idx]:copy(self.fake_B)
else
if opt.gpu > 0 then
self.real_A = input.real_A:cuda()
self.real_B = input.real_B:cuda()
else
self.real_A = input.real_A:clone()
self.real_B = input.real_B:clone()
end
self.fake_B = self.netG:forward(self.real_A):clone()
end
end
-- create closure to evaluate f(X) and df/dX of discriminator
-- Accumulates D's gradients on one real and one fake conditional pair and
-- returns the averaged loss.
function Pix2PixModel:fDx_basic(x, gradParams, netD, netG, real, fake, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()
  -- real branch: D should score the real pair as real
  local pred_real = netD:forward(real)
  local loss_real = self.criterionGAN:forward(pred_real, self.real_label)
  netD:backward(real, self.criterionGAN:backward(pred_real, self.real_label))
  -- fake branch: D should score the generated pair as fake
  local pred_fake = netD:forward(fake)
  local loss_fake = self.criterionGAN:forward(pred_fake, self.fake_label)
  netD:backward(fake, self.criterionGAN:backward(pred_fake, self.fake_label))
  -- average the two halves
  local loss_total = (loss_real + loss_fake) / 2.0
  return loss_total, gradParams
end
-- Discriminator closure: samples a (possibly historical) fake pair from the
-- image pool and delegates to fDx_basic.
-- Fix: `fake_AB` and `gradParams` were accidental globals; declared local.
function Pix2PixModel:fDx(x, opt)
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, self.netG,
                                         self.real_AB, fake_AB, opt)
  return self.errD, gradParams
end
-- Generator update: adversarial loss on the conditional pair plus L1 between
-- generated and ground-truth B (weighted by opt.lambda_A). Only the B-channel
-- slice of D's gradient flows back into G.
-- Fix: `real_B`, `real_A`, `fake_B` were accidental globals; declared local.
function Pix2PixModel:fGx_basic(x, netG, netD, real, fake, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()
  -- First. G(A) should fake the discriminator
  local output = netD:forward(fake)
  local errG = self.criterionGAN:forward(output, self.real_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
  -- updateGradInput backprops through D without touching D's parameter grads
  local dgan_loss_do = netD:updateGradInput(fake, dgan_loss_dd)
  -- Second. G(A) should be close to the real
  local real_B = real[self.B_idx]
  local real_A = real[self.A_idx]
  local fake_B = fake[self.B_idx]
  local errL1 = self.criterionL1:forward(fake_B, real_B) * opt.lambda_A
  local dl1_loss_do = self.criterionL1:backward(fake_B, real_B) * opt.lambda_A
  netG:backward(real_A, dgan_loss_do[self.B_idx] + dl1_loss_do)
  return gradParametersG, errG, errL1
end
-- Generator closure for optim: wraps fGx_basic and caches the losses on self.
function Pix2PixModel:fGx(x, opt)
  local grads, errG, errL1 = self:fGx_basic(x, self.netG, self.netD,
      self.real_AB, self.fake_AB, self.gradParametersG, opt)
  self.gradParametersG, self.errG, self.errL1 = grads, errG, errL1
  return self.errG, self.gradParametersG
end
-- Runs the backprop gradient descent
-- Corresponds to a single batch of data
function Pix2PixModel:OptimizeParameters(opt)
  optim.adam(function(x) return self:fDx(x, opt) end, self.parametersD, self.optimStateD)
  optim.adam(function(x) return self:fGx(x, opt) end, self.parametersG, self.optimStateG)
end
-- This function can be used to reset momentum after each epoch
-- Rebuilds the flattened parameter/gradient views for G and D.
function Pix2PixModel:RefreshParameters()
  -- release the old views first so memory use does not spike while re-flattening
  self.parametersD, self.gradParametersD = nil, nil
  self.parametersG, self.gradParametersG = nil, nil
  -- re-flatten for the optimizer
  self.parametersG, self.gradParametersG = self.netG:getParameters()
  self.parametersD, self.gradParametersD = self.netD:getParameters()
end
-- This function updates the learning rate; lr for the first opt.niter iterations; gradually decreases the lr to 0 for the next opt.niter_decay iterations
function Pix2PixModel:UpdateLearningRate(opt)
  local step = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  local new_lr = old_lr - step
  for _, state in ipairs({self.optimStateD, self.optimStateG}) do
    state['learningRate'] = new_lr
  end
  print(('update learning rate: %f -> %f'):format(old_lr, new_lr))
end
-- Save the current model to the file system
function Pix2PixModel:Save(prefix, opt)
  local checkpoints = {
    {self.netG, '_net_G.t7'},
    {self.netD, '_net_D.t7'},
  }
  for _, entry in ipairs(checkpoints) do
    util.save_model(entry[1], prefix .. entry[2], 1.0)
  end
end
-- returns a string that describes the current errors (-1 = not yet computed)
-- Fix: `description` was an accidental global; declared local.
function Pix2PixModel:GetCurrentErrorDescription()
  local description = ('G: %.4f D: %.4f L1: %.4f'):format(
                         self.errG and self.errG or -1, self.errD and self.errD or -1, self.errL1 and self.errL1 or -1)
  return description
end
-- returns a string that describes the display plot configuration
function Pix2PixModel:DisplayPlot(opt)
  local plot_keys = 'errG,errD,errL1'
  return plot_keys
end
-- returns current errors (entries may be nil before the first step)
function Pix2PixModel:GetCurrentErrors()
  return {errG = self.errG, errD = self.errD, errL1 = self.errL1}
end
-- returns a table of image/label pairs that describe
-- the current results.
-- |return|: a table of table. List of image/label pairs
function Pix2PixModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize -- kept for interface parity; not used below
  local visuals = {
    {img = self.real_A, label = 'real_A'},
    {img = self.fake_B, label = 'fake_B'},
    {img = self.real_B, label = 'real_B'},
  }
  return visuals
end
--------------------------------------------------------------------------------
-- Configure options
--------------------------------------------------------------------------------
local options = {}
-- options for train (defaults; each can be overridden via an environment
-- variable of the same name, see options.parse_options)
local opt_train = {
DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
batchSize = 1, -- # images in batch
loadSize = 143, -- scale images to this size
fineSize = 128, -- then crop to this size
ngf = 64, -- # of gen filters in first conv layer
ndf = 64, -- # of discrim filters in first conv layer
input_nc = 3, -- # of input image channels
output_nc = 3, -- # of output image channels
niter = 100, -- # of iter at starting learning rate
niter_decay = 100, -- # of iter to linearly decay learning rate to zero
lr = 0.0002, -- initial learning rate for adam
beta1 = 0.5, -- momentum term of adam
ntrain = math.huge, -- # of examples per epoch. math.huge for full dataset
flip = 1, -- if flip the images for data augmentation
display_id = 10, -- display window id.
display_winsize = 128, -- display window size
display_freq = 25, -- display the current results every display_freq iterations
gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
name = '', -- name of the experiment, should generally be passed on the command line
which_direction = 'AtoB', -- AtoB or BtoA
phase = 'train', -- train, val, test, etc
nThreads = 2, -- # threads for loading data
save_epoch_freq = 1, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models)
save_latest_freq = 5000, -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model)
print_freq = 50, -- print the debug information every print_freq iterations
save_display_freq = 2500, -- save the current display of results every save_display_freq_iterations
continue_train = 0, -- if continue training, load the latest model: 1: true, 0: false
serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly
checkpoints_dir = './checkpoints', -- models are saved here
cache_dir = './cache', -- cache files are saved here
cudnn = 1, -- set to 0 to not use cudnn
which_model_netD = 'basic', -- selects model to use for netD
which_model_netG = 'resnet_6blocks', -- selects model to use for netG
norm = 'instance', -- batch or instance normalization
n_layers_D = 3, -- only used if which_model_netD=='n_layers'
content_loss = 'pixel', -- content loss type: pixel, vgg
layer_name = 'pixel', -- layer used in content loss (e.g. relu4_2)
lambda_A = 10.0, -- weight for cycle loss (A -> B -> A)
lambda_B = 10.0, -- weight for cycle loss (B -> A -> B)
model = 'cycle_gan', -- which mode to run. 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'
use_lsgan = 1, -- if 1, use least square GAN, if 0, use vanilla GAN
align_data = 0, -- if > 0, use the dataloader for where the images are aligned
pool_size = 50, -- the size of image buffer that stores previously generated images
resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
lambda_identity = 0.5, -- use identity mapping. Setting opt.lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.lambda_identity = 0.1
use_optnet = 0, -- use optnet to save GPU memory during test
}
-- options for test (defaults; overridable via environment variables, see
-- options.parse_options)
local opt_test = {
DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
loadSize = 128, -- scale images to this size
fineSize = 128, -- then crop to this size
flip = 0, -- horizontal mirroring data augmentation
display = 1, -- display samples while training. 0 = false
display_id = 200, -- display window id.
gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
how_many = 'all', -- how many test images to run (set to all to run on every image found in the data/phase folder)
phase = 'test', -- train, val, test, etc
aspect_ratio = 1.0, -- aspect ratio of result images
norm = 'instance', -- batch norm or instance norm
name = '', -- name of experiment, selects which model to run, should generally be passed on command line
input_nc = 3, -- # of input image channels
output_nc = 3, -- # of output image channels
serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly
cudnn = 1, -- set to 0 to not use cudnn (untested)
checkpoints_dir = './checkpoints', -- loads models from here
cache_dir = './cache', -- cache files are saved here
results_dir='./results/', -- saves results here
which_epoch = 'latest', -- which epoch to test? set to 'latest' to use latest cached model
model = 'cycle_gan', -- which mode to run. 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'; to use pretrained model, select `one_direction_test`
align_data = 0, -- if > 0, use the dataloader for pix2pix
which_direction = 'AtoB', -- AtoB or BtoA
resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
}
--------------------------------------------------------------------------------
-- util functions
--------------------------------------------------------------------------------
-- Shallow-copies an options table so one mode's overrides do not leak into another.
function options.clone(opt)
  local copy = {}
  for k, v in pairs(opt) do
    copy[k] = v
  end
  return copy
end
--- Select the train/test default options, apply environment-variable
-- overrides, print the resulting configuration sorted by key, and persist it
-- under the checkpoint directory.
-- Note: the selected table is also assigned to the global `opt` (the rest of
-- the codebase reads that global), so that side effect is preserved.
-- @param mode 'train' or 'test'
-- @return the resolved options table, or nil for an unknown mode
function options.parse_options(mode)
  if mode == 'train' then
    opt = opt_train
    opt.test = 0
  elseif mode == 'test' then
    opt = opt_test
    opt.test = 1
  else
    print("Invalid option [" .. mode .. "]")
    return nil
  end
  -- one-line argument parser: environment variables override the defaults
  for k, v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
  if mode == 'test' then
    opt.nThreads = 1
    opt.continue_train = 1
    opt.batchSize = 1 -- test code only supports batchSize=1
  end
  -- print options sorted by key (local: the original leaked `keyset` as a global)
  local keyset = {}
  for k, _ in pairs(opt) do
    table.insert(keyset, k)
  end
  table.sort(keyset)
  print("------------------- Options -------------------")
  for _, k in ipairs(keyset) do
    print(('%+25s: %s'):format(k, opt[k]))
  end
  print("-----------------------------------------------")
  -- create checkpoint and visualization directories
  paths.mkdir(opt.checkpoints_dir)
  paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))
  opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
  paths.mkdir(opt.visual_dir)
  -- save the resolved options to disk for reproducibility
  -- (local: the original leaked the file handle `fd` as a global)
  local fd = io.open(paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt'), 'w')
  for _, k in ipairs(keyset) do
    fd:write(("%+25s: %s\n"):format(k, opt[k]))
  end
  fd:close()
  return opt
end
return options
#!/bin/bash
# Download a pretrained CycleGAN generator checkpoint plus the Places-VGG
# caffemodel/prototxt used for the perceptual content loss.
# Usage: bash download_model.sh <model_name>
FILE=$1
echo "Note: available models are apple2orange, facades_photo2label, map2sat, orange2apple, style_cezanne, style_ukiyoe, summer2winter_yosemite, zebra2horse, facades_label2photo, horse2zebra,monet2photo, sat2map, style_monet,style_vangogh, winter2summer_yosemite, iphone2dslr_flower"
echo "Specified [$FILE]"
mkdir -p "./checkpoints/${FILE}_pretrained"
# fixed: ./models must exist before wget -O can write into it
mkdir -p ./models
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/$FILE.t7
MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.t7
# NOTE(review): -N (timestamping) is ignored when combined with -O; files are
# always re-downloaded. Kept for compatibility with the original behavior.
wget -N "$URL" -O "$MODEL_FILE"
URL1=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.caffemodel
MODEL_FILE1=./models/places_vgg.caffemodel
URL2=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.prototxt
MODEL_FILE2=./models/places_vgg.prototxt
wget -N "$URL1" -O "$MODEL_FILE1"
wget -N "$URL2" -O "$MODEL_FILE2"
name: "VGG-Places365"
input: "data"
input_dim: 1
input_dim: 3
input_dim: 224
input_dim: 224
layer {
name: "conv1_1"
type: "Convolution"
bottom: "data"
top: "conv1_1"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 64
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu1_1"
type: "ReLU"
bottom: "conv1_1"
top: "conv1_1"
}
layer {
name: "conv1_2"
type: "Convolution"
bottom: "conv1_1"
top: "conv1_2"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 64
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu1_2"
type: "ReLU"
bottom: "conv1_2"
top: "conv1_2"
}
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1_2"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv2_1"
type: "Convolution"
bottom: "pool1"
top: "conv2_1"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 128
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu2_1"
type: "ReLU"
bottom: "conv2_1"
top: "conv2_1"
}
layer {
name: "conv2_2"
type: "Convolution"
bottom: "conv2_1"
top: "conv2_2"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 128
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu2_2"
type: "ReLU"
bottom: "conv2_2"
top: "conv2_2"
}
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2_2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv3_1"
type: "Convolution"
bottom: "pool2"
top: "conv3_1"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu3_1"
type: "ReLU"
bottom: "conv3_1"
top: "conv3_1"
}
layer {
name: "conv3_2"
type: "Convolution"
bottom: "conv3_1"
top: "conv3_2"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu3_2"
type: "ReLU"
bottom: "conv3_2"
top: "conv3_2"
}
layer {
name: "conv3_3"
type: "Convolution"
bottom: "conv3_2"
top: "conv3_3"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu3_3"
type: "ReLU"
bottom: "conv3_3"
top: "conv3_3"
}
layer {
name: "pool3"
type: "Pooling"
bottom: "conv3_3"
top: "pool3"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv4_1"
type: "Convolution"
bottom: "pool3"
top: "conv4_1"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu4_1"
type: "ReLU"
bottom: "conv4_1"
top: "conv4_1"
}
layer {
name: "conv4_2"
type: "Convolution"
bottom: "conv4_1"
top: "conv4_2"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu4_2"
type: "ReLU"
bottom: "conv4_2"
top: "conv4_2"
}
layer {
name: "conv4_3"
type: "Convolution"
bottom: "conv4_2"
top: "conv4_3"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu4_3"
type: "ReLU"
bottom: "conv4_3"
top: "conv4_3"
}
layer {
name: "pool4"
type: "Pooling"
bottom: "conv4_3"
top: "pool4"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv5_1"
type: "Convolution"
bottom: "pool4"
top: "conv5_1"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu5_1"
type: "ReLU"
bottom: "conv5_1"
top: "conv5_1"
}
layer {
name: "conv5_2"
type: "Convolution"
bottom: "conv5_1"
top: "conv5_2"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu5_2"
type: "ReLU"
bottom: "conv5_2"
top: "conv5_2"
}
layer {
name: "conv5_3"
type: "Convolution"
bottom: "conv5_2"
top: "conv5_3"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu5_3"
type: "ReLU"
bottom: "conv5_3"
top: "conv5_3"
}
layer {
name: "pool5"
type: "Pooling"
bottom: "conv5_3"
top: "pool5"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "fc6"
type: "InnerProduct"
bottom: "pool5"
top: "fc6"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu6"
type: "ReLU"
bottom: "fc6"
top: "fc6"
}
layer {
name: "drop6"
type: "Dropout"
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc7"
type: "InnerProduct"
bottom: "fc6"
top: "fc7"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
inner_product_param {
num_output: 4096
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}
layer {
name: "relu7"
type: "ReLU"
bottom: "fc7"
top: "fc7"
}
layer {
name: "drop7"
type: "Dropout"
bottom: "fc7"
top: "fc7"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc8a"
type: "InnerProduct"
bottom: "fc7"
top: "fc8a"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
inner_product_param {
num_output: 365
}
}
layer {
name: "prob"
type: "Softmax"
bottom: "fc8a"
top: "prob"
}
-- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
--
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
require 'image'
require 'nn'
require 'nngraph'
require 'models.architectures'
-- util/options/opt/visualizer are deliberately global: later code in this
-- script (and the models) reads them.
util = paths.dofile('util/util.lua')
options = require 'options'
opt = options.parse_options('test')
-- initialize torch GPU/CPU mode
if opt.gpu > 0 then
require 'cutorch'
require 'cunn'
cutorch.setDevice(opt.gpu)
print ("GPU Mode")
torch.setdefaulttensortype('torch.CudaTensor')
else
torch.setdefaulttensortype('torch.FloatTensor')
print ("CPU Mode")
end
-- setup visualization
visualizer = require 'util/visualizer'
--- Append every array element of t2 onto the end of t1, in place.
-- @param t1 destination array (mutated)
-- @param t2 source array (left unchanged)
-- @return t1, to allow call chaining
function TableConcat(t1, t2)
  for index = 1, #t2 do
    t1[#t1 + 1] = t2[index]
  end
  return t1
end
-- load data: aligned loader pairs A/B images (pix2pix-style), unaligned
-- loader samples the two domains independently (CycleGAN-style)
local data_loader = nil
if opt.align_data > 0 then
require 'data.aligned_data_loader'
data_loader = AlignedDataLoader()
else
require 'data.unaligned_data_loader'
data_loader = UnalignedDataLoader()
end
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)
-- resolve 'all' to the actual dataset size, then clamp to it
if opt.how_many == 'all' then
opt.how_many = data_loader:size()
end
opt.how_many = math.min(opt.how_many, data_loader:size())
-- set batch/instance normalization
set_normalization(opt.norm)
-- load model: continue_train=1 presumably makes Initialize load the saved
-- checkpoint weights -- NOTE(review): confirm against the model classes
opt.continue_train = 1
-- define model
if opt.model == 'cycle_gan' then
require 'models.cycle_gan_model'
model = CycleGANModel()
elseif opt.model == 'one_direction_test' then
require 'models.one_direction_test_model'
model = OneDirectionTestModel()
elseif opt.model == 'pix2pix' then
require 'models.pix2pix_model'
model = Pix2PixModel()
elseif opt.model == 'bigan' then
require 'models.bigan_model'
model = BiGANModel()
elseif opt.model == 'content_gan' then
require 'models.content_gan_model'
model = ContentGANModel()
else
error('Please specify a correct model')
end
model:Initialize(opt)
local pathsA = {} -- paths to images A tested on
local pathsB = {} -- paths to images B tested on
-- output layout: <results_dir>/<name>/<epoch>_<phase>/images/<label>/...
local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
paths.mkdir(web_dir)
local image_dir = paths.concat(web_dir, 'images')
paths.mkdir(image_dir)
-- s1/s2: output image width/height (s2 corrects for aspect ratio)
s1 = opt.fineSize
s2 = opt.fineSize / opt.aspect_ratio
visuals = {}
-- Main test loop: run the model on each batch, save the result images, then
-- build a simple HTML index page showing all outputs side by side.
for n = 1, math.floor(opt.how_many) do
  print('processing batch ' .. n)
  local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
  cur_pathsA = util.basename_batch(cur_pathsA)
  cur_pathsB = util.basename_batch(cur_pathsB)
  print('pathsA', cur_pathsA)
  -- fixed: was the undefined global `cur_PathsB`, which always printed nil
  print('pathsB', cur_pathsB)
  model:Forward({real_A=cur_dataA, real_B=cur_dataB}, opt)
  visuals = model:GetCurrentVisuals(opt, opt.fineSize)
  for i,visual in ipairs(visuals) do
    -- width-/height-preserving modes: let the saver keep the natural size
    if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      s1 = nil
      s2 = nil
    end
    -- fixed: '%.jpg' matches a literal dot; the unescaped '.jpg' pattern
    -- would also match e.g. 'ajpg' anywhere in the name
    visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), {string.gsub(cur_pathsA[1],'%.jpg','.png')}, s1, s2)
  end
  print('Saved images to: ', image_dir)
  pathsA = TableConcat(pathsA, cur_pathsA)
  pathsB = TableConcat(pathsB, cur_pathsB)
end
-- collect the visual labels (column headers) from the last processed batch
labels = {}
for i,visual in ipairs(visuals) do
  table.insert(labels, visual.label)
end
-- make webpage: one row per test image, one column per visual label
io.output(paths.concat(web_dir, 'index.html'))
io.write('<table style="text-align:center;">')
io.write('<tr><td> Image </td>')
for i = 1, #labels do
  io.write('<td>' .. labels[i] .. '</td>')
end
io.write('</tr>')
for n = 1,math.floor(opt.how_many) do
  io.write('<tr>')
  io.write('<td>' .. tostring(n) .. '</td>')
  for j = 1, #labels do
    local label = labels[j]
    io.write('<td><img src="./images/' .. label .. '/' .. string.gsub(pathsA[n],'%.jpg','.png') .. '"/></td>')
  end
  io.write('</tr>')
end
io.write('</table>')
-- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
require 'torch'
require 'nn'
require 'optim'
-- util/content/options/opt/visualizer are deliberately global: helper
-- functions defined later in this script read them.
util = paths.dofile('util/util.lua')
content = paths.dofile('util/content_loss.lua')
require 'image'
require 'models.architectures'
-- load configuration file
options = require 'options'
opt = options.parse_options('train')
-- setup visualization
visualizer = require 'util/visualizer'
-- initialize torch GPU/CPU mode
if opt.gpu > 0 then
require 'cutorch'
require 'cunn'
cutorch.setDevice(opt.gpu)
print ("GPU Mode")
torch.setdefaulttensortype('torch.CudaTensor')
else
torch.setdefaulttensortype('torch.FloatTensor')
print ("CPU Mode")
end
-- load data: aligned loader pairs A/B images (pix2pix-style), unaligned
-- loader samples the two domains independently (CycleGAN-style)
local data_loader = nil
if opt.align_data > 0 then
require 'data.aligned_data_loader'
data_loader = AlignedDataLoader()
else
require 'data.unaligned_data_loader'
data_loader = UnalignedDataLoader()
end
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)
-- set batch/instance normalization
set_normalization(opt.norm)
--- timers: epoch_tm measures a whole epoch, tm measures one iteration
local epoch_tm = torch.Timer()
local tm = torch.Timer()
-- define model; each model asserts that its required data loader is in use
local model = nil
local display_plot = nil
if opt.model == 'cycle_gan' then
assert(data_loader:name() == 'UnalignedDataLoader')
require 'models.cycle_gan_model'
model = CycleGANModel()
elseif opt.model == 'pix2pix' then
require 'models.pix2pix_model'
assert(data_loader:name() == 'AlignedDataLoader')
model = Pix2PixModel()
elseif opt.model == 'bigan' then
assert(data_loader:name() == 'UnalignedDataLoader')
require 'models.bigan_model'
model = BiGANModel()
elseif opt.model == 'content_gan' then
require 'models.content_gan_model'
assert(data_loader:name() == 'UnalignedDataLoader')
model = ContentGANModel()
else
error('Please specify a correct model')
end
-- print the model name
print('Model ' .. model:model_name() .. ' was specified.')
model:Initialize(opt)
-- set up the loss plot
require 'util/plot_util'
plotUtil = PlotUtil()
display_plot = model:DisplayPlot(opt)
plotUtil:Initialize(display_plot, opt.display_id, opt.name)
--------------------------------------------------------------------------------
-- Helper Functions
--------------------------------------------------------------------------------
--- Push the model's current visual outputs to the display server.
-- Each visual is shown in its own window (opt.display_id + index), titled
-- with the experiment name and the visual's label.
function visualize_current_results()
  local current = model:GetCurrentVisuals(opt)
  for idx, item in ipairs(current) do
    visualizer.disp_image(item.img, opt.display_winsize,
      opt.display_id + idx, opt.name .. ' ' .. item.label)
  end
end
--- Save the model's current visual outputs as JPEGs under opt.visual_dir.
-- File names encode the epoch, iteration counter, and visual label.
-- @param epoch   current epoch number
-- @param counter global iteration counter
function save_current_results(epoch, counter)
  local visuals = model:GetCurrentVisuals(opt)
  for i, visual in ipairs(visuals) do
    -- `local` added: the original leaked `output_path` into the global namespace
    local output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg')
    visualizer.save_results(visual.img, output_path)
  end
end
--- Print a one-line progress report: epoch, batch index / total batches,
-- per-item iteration time and data-loading time, then the model's own
-- error description string.
function print_current_errors(epoch, counter_in_epoch)
print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
.. '%s'):
format(epoch, ((counter_in_epoch-1) / opt.batchSize),
math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize),
tm:time().real / opt.batchSize,
data_loader:time_elapsed_to_fetch_data() / opt.batchSize,
model:GetCurrentErrorDescription()
))
end
--- Append the model's current error values to the loss plot.
-- @param epoch         current epoch (integer part of the x coordinate)
-- @param counter_ratio fraction of the epoch completed (fractional part)
-- @param opt           options table, forwarded to the model
function plot_current_errors(epoch, counter_ratio, opt)
  local current_errors = model:GetCurrentErrors(opt)
  local x_coordinate = { epoch + counter_ratio }
  plotUtil:Display(x_coordinate, current_errors)
end
--------------------------------------------------------------------------------
-- Main Training Loop
--------------------------------------------------------------------------------
local counter = 0
local num_batches = math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize)
print('#training iterations: ' .. opt.niter+opt.niter_decay )
-- niter + niter_decay epochs total; UpdateLearningRate is only invoked once
-- epoch exceeds niter (see the end of this loop)
for epoch = 1, opt.niter+opt.niter_decay do
epoch_tm:reset()
for counter_in_epoch = 1, math.min(data_loader:size(), opt.ntrain), opt.batchSize do
tm:reset()
-- load a batch and run G on that batch
local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch()
model:Forward({real_A=real_dataA, real_B=real_dataB}, opt)
-- run forward pass
opt.counter = counter
-- run backward pass
model:OptimizeParameters(opt)
-- display on the web server
if counter % opt.display_freq == 0 and opt.display_id > 0 then
visualize_current_results()
end
-- logging
if counter % opt.print_freq == 0 then
print_current_errors(epoch, counter_in_epoch)
plot_current_errors(epoch, counter_in_epoch/num_batches, opt)
end
-- save latest model (counter > 0 skips the very first iteration)
if counter % opt.save_latest_freq == 0 and counter > 0 then
print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter))
model:Save('latest', opt)
end
-- save latest results
if counter % opt.save_display_freq == 0 then
save_current_results(epoch, counter)
end
counter = counter + 1
end
-- save model at the end of epoch (both as 'latest' and under the epoch number)
if epoch % opt.save_epoch_freq == 0 then
print(('saving the model (epoch %d, iters %d)'):format(epoch, counter))
model:Save('latest', opt)
model:Save(epoch, opt)
end
-- print the timing information after each epoch
print(('End of epoch %d / %d \t Time Taken: %.3f'):
format(epoch, opt.niter+opt.niter_decay, epoch_tm:time().real))
-- update learning rate
if epoch > opt.niter then
model:UpdateLearningRate(opt)
end
-- refresh parameters
model:RefreshParameters(opt)
end
require 'nn'
--[[
Implements instance normalization as described in the paper
Instance Normalization: The Missing Ingredient for Fast Stylization
Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
https://arxiv.org/abs/1607.08022
This implementation is based on
https://github.com/DmitryUlyanov/texture_nets
]]
local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')
--- Construct an instance-normalization layer.
-- @param nOutput  number of feature channels (size of dim 2 of the input)
-- @param eps      numerical-stability constant (default 1e-5)
-- @param momentum running-statistics momentum (default 0.0)
-- @param affine   if true (default), learn per-channel weight and bias
function InstanceNormalization:__init(nOutput, eps, momentum, affine)
parent.__init(self)
-- per-channel running statistics, shared across batch items
self.running_mean = torch.zeros(nOutput)
self.running_var = torch.ones(nOutput)
self.eps = eps or 1e-5
self.momentum = momentum or 0.0
if affine ~= nil then
assert(type(affine) == 'boolean', 'affine has to be true/false')
self.affine = affine
else
self.affine = true
end
self.nOutput = nOutput
-- batch size of the previous forward; forces rebuilding the internal BN
-- on the first call (see updateOutput)
self.prev_batch_size = -1
if self.affine then
self.weight = torch.Tensor(nOutput):uniform()
self.bias = torch.Tensor(nOutput):zero()
self.gradWeight = torch.Tensor(nOutput)
self.gradBias = torch.Tensor(nOutput)
end
end
--- Forward pass. Instance norm is implemented by reshaping the (B,C,H,W)
-- input to (1, B*C, H, W) and running SpatialBatchNormalization over it, so
-- each (sample, channel) pair is normalized over its own spatial extent.
function InstanceNormalization:updateOutput(input)
self.output = self.output or input.new()
assert(input:size(2) == self.nOutput)
local batch_size = input:size(1)
-- rebuild the internal BN whenever the batch size or tensor type changed,
-- seeding its running stats from ours (tiled once per batch item)
if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then
self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
self.bn:type(self:type())
self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size))
self.bn.running_var:copy(self.running_var:repeatTensor(batch_size))
self.prev_batch_size = input:size(1)
end
-- Get statistics: fold the BN's per-(sample,channel) stats back to
-- per-channel by averaging over the batch dimension
self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1))
self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1))
-- Set params for BN: tile our per-channel affine params across the batch
if self.affine then
self.bn.weight:copy(self.weight:repeatTensor(batch_size))
self.bn.bias:copy(self.bias:repeatTensor(batch_size))
end
local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
self.output = self.bn:forward(input_1obj):viewAs(input)
return self.output
end
--- Backward pass: delegate to the internal BN on the (1, B*C, H, W) view,
-- then fold its per-(sample,channel) parameter gradients back to per-channel
-- by summing over the batch dimension.
function InstanceNormalization:updateGradInput(input, gradOutput)
self.gradInput = self.gradInput or gradOutput.new()
assert(self.bn)
local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
-- zero the BN's accumulators so only this call's gradients are summed below
if self.affine then
self.bn.gradWeight:zero()
self.bn.gradBias:zero()
end
self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input)
if self.affine then
self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1))
self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1))
end
return self.gradInput
end
--- Release intermediate buffers (and the internal BN's) to shrink
-- serialized model size.
function InstanceNormalization:clearState()
self.output = self.output.new()
self.gradInput = self.gradInput.new()
if self.bn then
self.bn:clearState()
end
end
-- evaluate/training are deliberately overridden as no-ops, so the layer
-- keeps the same normalization behavior in both modes (it never switches
-- the internal BN to evaluation mode).
function InstanceNormalization:evaluate()
end
function InstanceNormalization:training()
end
-- define nn module for VGG postprocessing
-- Converts generator output (assumed tanh range [-1,1] -- TODO confirm) into
-- Caffe-style VGG input: [0,255] pixels with the BGR mean subtracted.
local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module')
function VGG_postprocess:__init()
parent.__init(self)
end
--- Forward pass. NOTE: input:add(1):mul(127.5) modifies `input` IN PLACE;
-- self.output aliases the input tensor.
function VGG_postprocess:updateOutput(input)
self.output = input:add(1):mul(127.5)
-- print(self.output:max(), self.output:min())
-- warn (but do not fail) when values fall outside [0,255]
if self.output:max() > 255 or self.output:min() < 0 then
print(self.output:min(), self.output:max())
end
-- assert(self.output:min()>=0,"badly scaled inputs")
-- assert(self.output:max()<=255,"badly scaled inputs")
-- Caffe VGG mean pixel (BGR order), broadcast over batch and spatial dims.
-- NOTE(review): the :cuda() call makes this layer GPU-only.
local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
mean_pixel = mean_pixel:reshape(1,3,1,1)
mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda()
self.output:add(-1, mean_pixel)
return self.output
end
--- Backward pass: the forward is affine with scale 127.5, so the gradient is
-- just divided by 127.5. NOTE: gradOutput is modified in place.
function VGG_postprocess:updateGradInput(input, gradOutput)
self.gradInput = gradOutput:div(127.5)
return self.gradInput
end
require 'torch'
require 'nn'
local content = {}
--- Build a truncated VGG network for computing a perceptual content loss.
-- Input is bilinearly resized to 224x224 and converted to Caffe-style pixels
-- (nn.VGG_postprocess) before the copied VGG layers.
-- @param content_layer name of the last VGG layer to keep (e.g. 'relu4_2')
-- @return nn.Sequential ending at `content_layer`
function content.defineVGG(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- `local` added: the original leaked `cnn` into the global namespace
  local cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  -- copy layers one by one until (and including) the requested content layer
  for i = 1, #cnn do
    local layer = cnn:get(i):clone()
    contentFunc:add(layer)
    if layer.name == content_layer then
      print("Setting up content layer: ", layer.name)
      break
    end
  end
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end
--- Build a truncated AlexNet for computing a perceptual content loss.
-- Mirrors content.defineVGG but loads the AlexNet caffemodel instead.
-- @param content_layer name of the last layer to keep
-- @return nn.Sequential ending at `content_layer`
function content.defineAlexNet(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- `local` added: the original leaked `cnn` into the global namespace
  local cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  -- copy layers one by one until (and including) the requested content layer
  for i = 1, #cnn do
    local layer = cnn:get(i):clone()
    contentFunc:add(layer)
    if layer.name == content_layer then
      print("Setting up content layer: ", layer.name)
      break
    end
  end
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end
--- Build the feature extractor used by the content loss, if one is needed.
-- 'pixel' and 'none' need no network; 'vgg' returns a truncated VGG.
-- @param content_loss 'pixel' | 'none' | 'vgg'
-- @param layer_name   content layer passed through to defineVGG
-- @return an nn module, or nil when no feature network is required
function content.defineContent(content_loss, layer_name)
  if content_loss == 'vgg' then
    return content.defineVGG(layer_name)
  end
  if content_loss ~= 'pixel' and content_loss ~= 'none' then
    print("unsupported content loss")
  end
  return nil
end
--- Compute the content loss and its gradient w.r.t. fake_target.
-- @param criterionContent criterion comparing features (e.g. nn.MSECriterion)
-- @param real_source source-domain image batch
-- @param fake_target generated image batch
-- @param contentFunc feature extractor from defineContent (used for 'vgg')
-- @param loss_type   'none' | 'pixel' | 'vgg'
-- @param weight      scalar multiplier applied to loss and gradient
-- @return loss value, gradient tensor shaped like fake_target
function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
if loss_type == 'none' then
-- no content loss: zero loss and zero gradient
local errCont = 0.0
local df_d_content = torch.zeros(fake_target:size())
return errCont, df_d_content
elseif loss_type == 'pixel' then
-- direct pixel-space comparison
local errCont = criterionContent:forward(fake_target, real_source) * weight
local df_do_content = criterionContent:backward(fake_target, real_source)*weight
return errCont, df_do_content
elseif loss_type == 'vgg' then
-- compare in VGG feature space; clone() because contentFunc reuses buffers
local f_fake = contentFunc:forward(fake_target):clone()
local f_real = contentFunc:forward(real_source):clone()
local errCont = criterionContent:forward(f_fake, f_real) * weight
-- weight is already applied to df_do_tmp, hence the commented-out :mul(weight)
local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight
local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)--:mul(weight)
return errCont, df_do_content
else error("unsupported content loss")
end
end
return content
-- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua
-- removed error on nngraph
-- modules that can be converted to nn seamlessly
-- Class names are matched against torch.typename() with an 'nn.'/'cudnn.'
-- prefix inside cudnn_convert_custom below.
local layer_list = {
'BatchNormalization',
'SpatialBatchNormalization',
'SpatialConvolution',
'SpatialCrossMapLRN',
'SpatialFullConvolution',
'SpatialMaxPooling',
'SpatialAveragePooling',
'ReLU',
'Tanh',
'Sigmoid',
'SoftMax',
'LogSoftMax',
'VolumetricBatchNormalization',
'VolumetricConvolution',
'VolumetricFullConvolution',
'VolumetricMaxPooling',
'VolumetricAveragePooling',
}
-- goes over a given net and converts all layers to dst backend
-- for example: net = cudnn_convert_custom(net, cudnn)
-- same as cudnn.convert with gModule check commented out
--- Convert every layer of `net` between the nn and cudnn backends.
-- Same as cudnn.convert, but with the nn.gModule (nngraph) guard removed so
-- graph modules are converted as well.
-- for example: net = cudnn_convert_custom(net, cudnn)
-- @param net          network to convert (modified in place)
-- @param dst          destination backend table: `nn` or `cudnn`
-- @param exclusion_fn optional predicate; modules for which it returns true
--                     are left untouched
-- @return the converted network
function cudnn_convert_custom(net, dst, exclusion_fn)
  return net:replace(function(x)
    local y = 0 -- sentinel: stays 0 when no conversion rule applies to x
    local src = dst == nn and cudnn or nn
    local src_prefix = src == nn and 'nn.' or 'cudnn.'
    local dst_prefix = dst == nn and 'nn.' or 'cudnn.'
    -- Rebuild module x as the dst-backend class named `v`, copying all fields.
    local function convert(v)
      local y = {}
      torch.setmetatable(y, dst_prefix..v)
      if v == 'ReLU' then y = dst.ReLU() end -- because parameters
      for k,u in pairs(x) do y[k] = u end
      if src == cudnn and x.clearDesc then x.clearDesc(y) end
      if src == cudnn and v == 'SpatialAveragePooling' then
        y.divide = true
        -- fixed: read the pooling mode from the module `x`; `v` is just the
        -- class-name string, so the original `v.mode` was always nil and
        -- count_include_pad came out false regardless of the cudnn mode
        y.count_include_pad = x.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
      end
      if src == nn and string.find(v, 'Convolution') then
        y.groups = 1
      end
      return y
    end
    if exclusion_fn and exclusion_fn(x) then
      return x
    end
    local t = torch.typename(x)
    if t == 'nn.SpatialConvolutionMM' then
      y = convert('SpatialConvolution')
    elseif t == 'inn.SpatialCrossResponseNormalization' then
      y = convert('SpatialCrossMapLRN')
    else
      for i,v in ipairs(layer_list) do
        if torch.typename(x) == src_prefix..v then
          y = convert(v)
        end
      end
    end
    return y == 0 and x or y
  end)
end
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment