unaligned_data_loader.lua 1.47 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- Samples from the two datasets are not aligned (not paired with each other),
-- and the datasets can have different sizes.
--------------------------------------------------------------------------------
-- Presumably defines BaseDataLoader, which is used as the parent class below.
require 'data.base_data_loader'

local class = require 'class'
-- NOTE(review): data_util is assigned as a global (no `local`). Other files may
-- rely on that global, so confirm before making it local. `paths` is assumed
-- to be provided globally by the Torch runtime.
data_util = paths.dofile('data_util.lua')

-- Global class table; the second argument to `class` names the parent class.
UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader')

--- Construct the loader.
-- @param conf optional configuration table; nil is replaced by an empty table.
function UnalignedDataLoader:__init(conf)
  -- Apply the default before delegating so the base constructor receives a
  -- table rather than nil (assigning it afterwards would have no effect).
  conf = conf or {}
  BaseDataLoader.__init(self, conf)
end

--- Identifier of this loader class.
-- @return the string 'UnalignedDataLoader'
function UnalignedDataLoader:name()
  local loader_name = 'UnalignedDataLoader'
  return loader_name
end

--- Load dataset 'A' and dataset 'B' described by opt.
-- Side effect: sets opt.align_data = 0 on the caller's table before loading.
-- @param opt options table; opt.input_nc is passed when loading 'A' and
--   opt.output_nc when loading 'B'.
function UnalignedDataLoader:Initialize(opt)
  opt.align_data = 0
  local load_dataset = data_util.load_dataset
  self.dataA = load_dataset('A', opt, opt.input_nc)
  self.dataB = load_dataset('B', opt, opt.output_nc)
end

--- Fetch one batch from each dataset (A first, then B).
-- The two batches are drawn independently, since the datasets are not aligned.
-- @return four values (not a table): batchA, batchB, pathA, pathB, where each
--   batch/path pair is what getBatch() returned for dataset A or B.
function UnalignedDataLoader:LoadBatchForAllDatasets()
  local batchA, pathA = self.dataA:getBatch()
  local batchB, pathB = self.dataB:getBatch()
  return batchA, batchB, pathA, pathB
end

--- Size of a single dataset, or of the larger of the two.
-- @param dataset 'A' or 'B' selects that dataset; any other value
--   (including nil) returns the size of the larger dataset.
-- @return the value reported by the selected dataset's size() method
function UnalignedDataLoader:size(dataset)
  if dataset == 'A' then
    return self.dataA:size()
  elseif dataset == 'B' then
    return self.dataB:size()
  end
  -- By default, report the size of the larger dataset.
  local sizeA = self.dataA:size()
  local sizeB = self.dataB:size()
  return math.max(sizeA, sizeB)
end