"mmdet3d/vscode:/vscode.git/clone" did not exist on "d99dbce7005be2950404ef7e7811061d88c75453"
Commit 05f0839a authored by dengjb's avatar dengjb
Browse files

update

parents
local class = require 'class'
ImagePool= class('ImagePool')
require 'torch'
require 'image'
function ImagePool:__init(pool_size)
self.pool_size = pool_size
if pool_size > 0 then
self.num_imgs = 0
self.images = {}
end
end
function ImagePool:model_name()
return 'ImagePool'
end
--
-- function ImagePool:Initialize(pool_size)
-- -- torch.manualSeed(0)
-- -- assert(pool_size > 0)
-- self.pool_size = pool_size
-- if pool_size > 0 then
-- self.num_imgs = 0
-- self.images = {}
-- end
-- end
function ImagePool:Query(image)
-- print('query image')
if self.pool_size == 0 then
-- print('get identical image')
return image
end
if self.num_imgs < self.pool_size then
-- self.images.insert(image:clone())
self.num_imgs = self.num_imgs + 1
self.images[self.num_imgs] = image
return image
else
local p = math.random()
-- print('p' ,p)
-- os.exit()
if p > 0.5 then
-- print('use old image')
-- random_id = torch.Tensor(1)
-- random_id:random(1, self.pool_size)
local random_id = math.random(self.pool_size)
-- print('random_id', random_id)
local tmp = self.images[random_id]:clone()
self.images[random_id] = image:clone()
return tmp
else
return image
end
end
end
local class = require 'class'
PlotUtil = class('PlotUtil')
require 'torch'
disp = require 'display'
util = require 'util/util'
require 'image'
local unpack = unpack or table.unpack
function PlotUtil:__init(conf)
conf = conf or {}
end
function PlotUtil:model_name()
return 'PlotUtil'
end
function PlotUtil:Initialize(display_plot, display_id, name)
self.display_plot = string.split(string.gsub(display_plot, "%s+", ""), ",")
self.plot_config = {
title = name .. ' loss over time',
labels = {'epoch', unpack(self.display_plot)},
ylabel = 'loss',
win = display_id,
}
self.plot_data = {}
print('display_opt', self.display_plot)
end
function PlotUtil:Display(plot_vals, loss)
for k, v in ipairs(self.display_plot) do
if loss[v] ~= nil then
plot_vals[#plot_vals + 1] = loss[v]
end
end
table.insert(self.plot_data, plot_vals)
disp.plot(self.plot_data, self.plot_config)
end
--
-- code derived from https://github.com/soumith/dcgan.torch
--
local util = {}
require 'torch'
function util.BiasZero(net)
net:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
end
function util.checkEqual(A, B, name)
local dif = (A:float()-B:float()):abs():mean()
print(name, dif)
end
function util.containsValue(table, value)
for k, v in pairs(table) do
if v == value then return true end
end
return false
end
function util.CheckTensor(A, name)
print(name, A:min(), A:max(), A:mean())
end
function util.normalize(img)
-- rescale image to 0 .. 1
local min = img:min()
local max = img:max()
img = torch.FloatTensor(img:size()):copy(img)
img:add(-min):mul(1/(max-min))
return img
end
function util.normalizeBatch(batch)
for i = 1, batch:size(1) do
batch[i] = util.normalize(batch[i]:squeeze())
end
return batch
end
function util.basename_batch(batch)
for i = 1, #batch do
batch[i] = paths.basename(batch[i])
end
return batch
end
-- default preprocessing
--
-- Preprocesses an image before passing it to a net
-- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
function util.preprocess(img)
-- RGB to BGR
if img:size(1) == 3 then
local perm = torch.LongTensor{3, 2, 1}
img = img:index(1, perm)
end
-- [0,1] to [-1,1]
img = img:mul(2):add(-1)
-- check that input is in expected range
assert(img:max()<=1,"badly scaled inputs")
assert(img:min()>=-1,"badly scaled inputs")
return img
end
-- Undo the above preprocessing.
function util.deprocess(img)
-- BGR to RGB
if img:size(1) == 3 then
local perm = torch.LongTensor{3, 2, 1}
img = img:index(1, perm)
end
-- [-1,1] to [0,1]
img = img:add(1):div(2)
return img
end
function util.preprocess_batch(batch)
for i = 1, batch:size(1) do
batch[i] = util.preprocess(batch[i]:squeeze())
end
return batch
end
function util.print_tensor(name, x)
print(name, x:size(), x:min(), x:mean(), x:max())
end
function util.deprocess_batch(batch)
for i = 1, batch:size(1) do
batch[i] = util.deprocess(batch[i]:squeeze())
end
return batch
end
function util.scaleBatch(batch,s1,s2)
-- print('s1', s1)
-- print('s2', s2)
local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2)
for i = 1, batch:size(1) do
scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze()
end
return scaled_batch
end
function util.toTrivialBatch(input)
return input:reshape(1,input:size(1),input:size(2),input:size(3))
end
function util.fromTrivialBatch(input)
return input[1]
end
-- input is between -1 and 1
function util.jitter(input)
local noise = torch.rand(input:size())/256.0
input:add(1.0):mul(0.5*255.0/256.0):add(noise):add(-0.5):mul(2.0)
--local scaled = (input+1.0)*0.5
--local jittered = scaled*255.0/256.0 + torch.rand(input:size())/256.0
--local scaled_back = (jittered-0.5)*2.0
--return scaled_back
end
function util.scaleImage(input, loadSize)
-- replicate bw images to 3 channels
if input:size(1)==1 then
input = torch.repeatTensor(input,3,1,1)
end
input = image.scale(input, loadSize, loadSize)
return input
end
function util.getAspectRatio(path)
local input = image.load(path, 3, 'float')
local ar = input:size(3)/input:size(2)
return ar
end
function util.loadImage(path, loadSize, nc)
local input = image.load(path, 3, 'float')
input= util.preprocess(util.scaleImage(input, loadSize))
if nc == 1 then
input = input[{{1}, {}, {}}]
end
return input
end
function file_exists(filename)
local f = io.open(filename,"r")
if f ~= nil then io.close(f) return true else return false end
end
-- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations
function load_helper(filename, opt)
fileExists = file_exists(filename)
if not fileExists then
print('model not found! ' .. filename)
return nil
end
print(('loading previously trained model (%s)'):format(filename))
if opt.norm == 'instance' then
print('use InstanceNormalization')
require 'util.InstanceNormalization'
end
if opt.cudnn>0 then
require 'cudnn'
end
local net = torch.load(filename)
if opt.gpu > 0 then
require 'cunn'
net:cuda()
-- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
if net.forwardnodes then
for i=1,#net.forwardnodes do
if net.forwardnodes[i].data.module then
net.forwardnodes[i].data.module:cuda()
end
end
end
else
net:float()
end
net:apply(function(m) if m.weight then
m.gradWeight = m.weight:clone():zero();
m.gradBias = m.bias:clone():zero(); end end)
return net
end
function util.load_model(name, opt)
-- if opt['lambda_'.. name] > 0.0 then
-- print('not loading model '.. opt.checkpoints_dir .. opt.name ..
-- 'latest_net_' .. name .. '.t7' .. ' because opt.lambda is not greater than zero')
return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
'latest_net_' .. name .. '.t7'), opt)
-- end
end
function util.load_test_model(name, opt)
return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
opt.which_epoch .. '_net_' .. name .. '.t7'), opt)
end
-- load dataset from the file system
-- |name|: name of the dataset. It's currently either 'A' or 'B'
-- function util.load_dataset(name, nc, opt, nc)
-- local tensortype = torch.getdefaulttensortype()
-- torch.setdefaulttensortype('torch.FloatTensor')
--
-- local new_opt = options.clone(opt)
-- new_opt.manualSeed = torch.random(1, 10000) -- fix seed
-- new_opt.nc = nc
-- torch.manualSeed(new_opt.manualSeed)
-- local data_loader = paths.dofile('../data/data.lua')
-- new_opt.phase = new_opt.phase .. name
-- local data = data_loader.new(new_opt.nThreads, new_opt)
-- print("Dataset Size " .. name .. ": ", data:size())
--
-- torch.setdefaulttensortype(tensortype)
-- return data
-- end
function util.cudnn(net)
require 'cudnn'
require 'util/cudnn_convert_custom'
return cudnn_convert_custom(net, cudnn)
end
function util.save_model(net, net_name, weight)
if weight > 0.0 then
torch.save(paths.concat(opt.checkpoints_dir, opt.name, net_name), net:clearState())
end
end
return util
-------------------------------------------------------------
-- Various utilities for visualization through the web server
-------------------------------------------------------------
local visualizer = {}
require 'torch'
disp = nil
print(opt)
if opt.display_id > 0 then -- [hack]: assume that opt already existed
disp = require 'display'
end
util = require 'util/util'
require 'image'
-- function visualizer
function visualizer.disp_image(img_data, win_size, display_id, title)
images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size))
disp.image(images, {win=display_id, title=title})
end
function visualizer.save_results(img_data, output_path)
local tensortype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
local image_out = nil
local win_size = opt.display_winsize
images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size)))
if images:dim() == 3 then
image_out = images
else
for i = 1,images:size(1) do
img = images[i]
if image_out == nil then
image_out = img
else
image_out = torch.cat(image_out, img)
end
end
end
image.save(output_path, image_out)
torch.setdefaulttensortype(tensortype)
end
function visualizer.save_images(imgs, save_dir, impaths, s1, s2)
local tensortype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
batchSize = imgs:size(1)
imgs_f = util.deprocess_batch(imgs):float()
paths.mkdir(save_dir)
for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio
if s1 ~= nil and s2 ~= nil then
im_s = image.scale(imgs_f[i], s1, s2):float()
else
im_s = imgs_f[i]:float()
end
img_to_save = torch.FloatTensor(im_s:size()):copy(im_s)
image.save(paths.concat(save_dir, impaths[i]), img_to_save)
end
torch.setdefaulttensortype(tensortype)
end
return visualizer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment