one_direction_test_model.lua 2.12 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'

-- Load the shared utility helpers from ../util/util.lua.
-- NOTE(review): `util` is assigned without `local`, so it becomes a global;
-- other files may rely on that, so confirm before narrowing its scope.
util = paths.dofile('../util/util.lua')
-- Declare the OneDirectionTestModel class, naming 'BaseModel' as its parent.
-- This is also a global assignment (no `local`).
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')

-- Constructor.
-- @param conf  optional configuration table; an empty table is used when
--              it is omitted.
function OneDirectionTestModel:__init(conf)
  -- Apply the default BEFORE delegating: the original defaulted `conf`
  -- only after the base constructor had already been called, so the
  -- default never took effect and the base class could receive nil.
  conf = conf or {}
  BaseModel.__init(self, conf)
end

-- Reports the name of this model class.
-- @return the string 'OneDirectionTestModel'
function OneDirectionTestModel:model_name()
  local name = 'OneDirectionTestModel'
  return name
end

-- Sets up the test-time state of the model: allocates the input tensor,
-- loads the generator network, optionally runs optnet's memory
-- optimisation on it, flattens its parameters and prints their count.
-- @param opt  options table; reads batchSize, input_nc, fineSize and
--             use_optnet, and is forwarded to util.load_test_model.
function OneDirectionTestModel:Initialize(opt)
  -- Input buffer of size batchSize x input_nc x fineSize x fineSize.
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc,
                             opt.fineSize, opt.fineSize)

  -- Generator network used at test time.
  self.netG_A = util.load_test_model('G', opt)

  -- Optional: let optnet reuse buffers to save a bit of memory.
  if opt.use_optnet == 1 then
    local optnet = require 'optnet'
    local probe = torch.randn(1, opt.input_nc, 2, 2)
    local optnet_opts = {inplace=true, reuseBuffers=true}
    optnet.optimizeMemory(self.netG_A, probe, optnet_opts)
  end

  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  local n_params = self.parametersG_A:size(1)
  print(string.format('G_A = %d', n_params))
  print('------------------------------------------------')
end

-- Runs the forward pass of the generator and stores the result in
-- member variables of the class: `self.real_A` (the network input) and
-- `self.fake_B` (a copy of the network output).
-- @param input  table holding the tensors `real_A` and/or `real_B`;
--               `real_B` is used as the network input when
--               opt.which_direction == 'BtoA', `real_A` otherwise.
-- @param opt    options table; reads which_direction and gpu.
function OneDirectionTestModel:Forward(input, opt)
  -- Select the source tensor without writing back into `input`: the
  -- original overwrote input.real_A in the caller's table and cloned the
  -- tensor twice in the 'BtoA' case.
  local source
  if opt.which_direction == 'BtoA' then
    source = input.real_B
  else
    source = input.real_A
  end

  -- Work on a private copy so the caller's tensor is never aliased.
  local real_A = source:clone()
  if opt.gpu > 0 then
    real_A = real_A:cuda()
  end
  self.real_A = real_A

  -- Clone the output so fake_B does not alias the network's output tensor.
  self.fake_B = self.netG_A:forward(self.real_A):clone()
end

-- Re-fetches the generator's parameters and gradient parameters from
-- netG_A:getParameters() and stores them on the model.
function OneDirectionTestModel:RefreshParameters()
  -- Clear the previous references before asking the network again.
  self.parametersG_A = nil
  self.gradparametersG_A = nil

  local params, grad_params = self.netG_A:getParameters()
  self.parametersG_A = params
  self.gradparametersG_A = grad_params
end


-- Returns `im` with its second dimension repeated three times when that
-- dimension has size 1; any other input is returned unchanged.
-- Assumes `im` is a 4-D tensor whose second dimension is the channel
-- count (as for real_A in Initialize) -- TODO confirm for all callers.
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end

  local tiled = torch.repeatTensor(im, 1, 3, 1, 1)
  return tiled
end

-- Collects the current input and output images for display.
-- @param opt   options table; display_winsize is read when `size` is
--              not given.
-- @param size  optional size; defaulted below but not otherwise used by
--              this function.
-- @return array of {img=<tensor>, label=<string>} entries, in the order
--         real_A, fake_B.
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize

  local visuals = {
    {img=MakeIm3(self.real_A), label='real_A'},
    {img=MakeIm3(self.fake_B), label='fake_B'},
  }
  return visuals
end