image_pool.lua 1.28 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
local class = require 'class'
-- Declare the ImagePool class via the `class` package required above.
-- NOTE(review): assigned without `local`, so this creates a global; presumably
-- other files use `ImagePool` by name after requiring this file — confirm
-- before making it local.
ImagePool= class('ImagePool')

require 'torch'
require 'image'

--- Construct an image pool.
-- Records `pool_size`; the bookkeeping fields `num_imgs` (count of stored
-- images) and `images` (array of stored images) are only created when
-- `pool_size` is positive. A size of 0 makes `Query` return its input as-is.
-- @param pool_size number maximum number of images kept in the pool
function ImagePool:__init(pool_size)
  self.pool_size = pool_size
  -- Guard clause: nothing to set up unless the pool can actually hold images.
  if not (pool_size > 0) then
    return
  end
  self.images = {}
  self.num_imgs = 0
end

--- Name of this component.
-- @treturn string always the literal 'ImagePool'
function ImagePool:model_name()
  local name = 'ImagePool'
  return name
end
-- 
-- function ImagePool:Initialize(pool_size)
--   -- torch.manualSeed(0)
--   -- assert(pool_size > 0)
--   self.pool_size = pool_size
--   if pool_size > 0 then
--     self.num_imgs = 0
--     self.images = {}
--   end
-- end

--- Return either `image` itself or an image previously stored in the pool.
-- * Pool disabled (`pool_size <= 0`): `image` is returned unchanged.
-- * Pool not yet full: a private copy of `image` is stored and `image` is
--   returned.
-- * Pool full: with probability ~0.5 (`math.random() > 0.5`) a randomly chosen
--   stored image is evicted and returned, and a copy of `image` takes its
--   slot; otherwise `image` is returned and the pool is left untouched.
-- @param image object with a `:clone()` method (presumably a torch Tensor —
--   confirm against callers)
-- @return `image` or a previously pooled image
function ImagePool:Query(image)
  -- Non-positive sizes disable the pool. __init only creates num_imgs/images
  -- when pool_size > 0, so negative sizes must not reach the code below
  -- (previously they crashed comparing a nil num_imgs).
  if self.pool_size <= 0 then
    return image
  end

  if self.num_imgs < self.pool_size then
    -- Store a private copy: keeping a reference would let later in-place
    -- writes to the caller's tensor silently rewrite the pooled history.
    -- This matches the swap branch below, which also stores a clone.
    self.num_imgs = self.num_imgs + 1
    self.images[self.num_imgs] = image:clone()
    return image
  end

  if math.random() > 0.5 then
    local random_id = math.random(self.pool_size)
    -- Every pooled entry is a private copy owned by the pool, and this one
    -- is being evicted, so it can be handed to the caller without another
    -- clone.
    local old_image = self.images[random_id]
    self.images[random_id] = image:clone()
    return old_image
  end

  return image
end