-- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
--
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix

require 'image'
require 'nn'
require 'nngraph'
require 'models.architectures'
util = paths.dofile('util/util.lua')
options = require 'options'
opt = options.parse_options('test')

-- initialize torch GPU/CPU mode
if opt.gpu > 0 then
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu)
  print("GPU Mode")
  torch.setdefaulttensortype('torch.CudaTensor')
else
  torch.setdefaulttensortype('torch.FloatTensor')
  print("CPU Mode")
end

-- setup visualization
visualizer = require 'util/visualizer'

-- append the contents of t2 to t1 (in place) and return t1
function TableConcat(t1, t2)
  for i = 1, #t2 do
    t1[#t1 + 1] = t2[i]
  end
  return t1
end

-- load data
local data_loader = nil
if opt.align_data > 0 then
  require 'data.aligned_data_loader'
  data_loader = AlignedDataLoader()
else
  require 'data.unaligned_data_loader'
  data_loader = UnalignedDataLoader()
end
print("DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)

-- 'all' means the full test set; otherwise clamp to the dataset size
if opt.how_many == 'all' then
  opt.how_many = data_loader:size()
end
opt.how_many = math.min(opt.how_many, data_loader:size())

-- set batch/instance normalization
set_normalization(opt.norm)

-- load model (reuse trained weights from the specified epoch)
opt.continue_train = 1

-- define model
if opt.model == 'cycle_gan' then
  require 'models.cycle_gan_model'
  model = CycleGANModel()
elseif opt.model == 'one_direction_test' then
  require 'models.one_direction_test_model'
  model = OneDirectionTestModel()
elseif opt.model == 'pix2pix' then
  require 'models.pix2pix_model'
  model = Pix2PixModel()
elseif opt.model == 'bigan' then
  require 'models.bigan_model'
  model = BiGANModel()
elseif opt.model == 'content_gan' then
  require 'models.content_gan_model'
  model = ContentGANModel()
else
  error('Please specify a correct model')
end
model:Initialize(opt)

local pathsA = {}  -- paths to images A tested on
local pathsB = {}  -- paths to images B tested on

-- output directories: results_dir/name/<epoch>_<phase>/ and its images/ subfolder
local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
paths.mkdir(web_dir)
local image_dir = paths.concat(web_dir, 'images')
paths.mkdir(image_dir)

s1 = opt.fineSize
s2 = opt.fineSize / opt.aspect_ratio

visuals = {}
for n = 1, math.floor(opt.how_many) do
  print('processing batch ' .. n)

  local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
  cur_pathsA = util.basename_batch(cur_pathsA)
  cur_pathsB = util.basename_batch(cur_pathsB)
  print('pathsA', cur_pathsA)
  print('pathsB', cur_pathsB)

  model:Forward({real_A = cur_dataA, real_B = cur_dataB}, opt)
  visuals = model:GetCurrentVisuals(opt, opt.fineSize)

  for i, visual in ipairs(visuals) do
    -- when scaling by a single side, let the saved image keep its own size
    if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      s1 = nil
      s2 = nil
    end
    visualizer.save_images(visual.img, paths.concat(image_dir, visual.label),
                           {string.gsub(cur_pathsA[1], '.jpg', '.png')}, s1, s2)
  end

  print('Saved images to: ', image_dir)
  pathsA = TableConcat(pathsA, cur_pathsA)
  pathsB = TableConcat(pathsB, cur_pathsB)
end

-- collect the label of each visual (one webpage column per label)
labels = {}
for i, visual in ipairs(visuals) do
  table.insert(labels, visual.label)
end
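-- The index.html written below is a plain HTML table: a header row with one
-- column per visual label returned by model:GetCurrentVisuals, then one row
-- per test image whose cells embed the files saved under ./images/<label>/.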
-- make webpage
io.output(paths.concat(web_dir, 'index.html'))

io.write('<table>')

-- header row: image index plus one column per visual label
io.write('<tr><td>Image</td>')
for i = 1, #labels do
  io.write('<td>' .. labels[i] .. '</td>')
end
io.write('</tr>')

-- one row per test image; each cell embeds the image saved for that label
for n = 1, math.floor(opt.how_many) do
  io.write('<tr>')
  io.write('<td>' .. tostring(n) .. '</td>')
  for j = 1, #labels do
    label = labels[j]
    io.write('<td><img src="./images/' .. label .. '/' .. string.gsub(pathsA[n], '.jpg', '.png') .. '"/></td>')
  end
  io.write('</tr>')
end

io.write('</table>')
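
-- close the redirected default output so index.html is flushed to disk
io.close()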