aligned_data_loader.lua 1.37 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
--------------------------------------------------------------------------------
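-- Subclass of BaseDataLoader that provides data from two datasets, A and B.
-- The samples of the two datasets are aligned: A and B are sliced out of the
-- same loaded batch (along its second dimension), so sample i of A always
-- corresponds to sample i of B. Consequently both datasets have the same size.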
--------------------------------------------------------------------------------
require 'data.base_data_loader'

local class = require 'class'
data_util = paths.dofile('data_util.lua')

-- Class table for the aligned loader, derived from 'BaseDataLoader'.
-- Intentionally global (no `local`): presumably other modules refer to
-- AlignedDataLoader by name after requiring this file -- verify before
-- localizing it.
AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')

-- Constructor.
-- conf: optional configuration table; defaults to an empty table and is
--       forwarded to BaseDataLoader.__init.
function AlignedDataLoader:__init(conf)
  -- Apply the default *before* delegating. Previously this assignment ran
  -- after the base constructor and conf was never read again, so the
  -- default never took effect.
  conf = conf or {}
  BaseDataLoader.__init(self, conf)
end

-- Identifier of this loader type; always the string 'AlignedDataLoader'.
function AlignedDataLoader:name()
  local loader_name = 'AlignedDataLoader'
  return loader_name
end

-- Prepare the loader from the option table `opt`.
-- Side effects: sets opt.align_data = 1 on the caller's table (presumably
-- read during dataset loading to select paired data -- confirm), records the
-- {first, last} index ranges of A and B used to slice each batch, and loads
-- the dataset into self.data.
function AlignedDataLoader:Initialize(opt)
  opt.align_data = 1

  local n_in, n_out = opt.input_nc, opt.output_nc
  -- A occupies indices 1..input_nc, B the next output_nc indices.
  self.idx_A = {1, n_in}
  self.idx_B = {n_in + 1, n_in + n_out}

  -- NOTE(review): the count passed to load_dataset is hard-coded to 3; the
  -- expression opt.input_nc + opt.output_nc was commented out in the
  -- original. Confirm what data_util.load_dataset expects before changing.
  local NUM_CHANNELS = 3
  self.data = data_util.load_dataset('', opt, NUM_CHANNELS)
end

-- Fetch one batch from self.data and split it into its A and B parts.
-- Both parts come from the same getBatch() result. They are selected on the
-- second index using the ranges self.idx_A / self.idx_B set in Initialize;
-- the remaining three indices are taken whole.
-- |return|: four values -- the A slice, the B slice, and the path(s) from
-- getBatch twice (once for A, once for B; the samples are aligned, so they
-- share paths).
function AlignedDataLoader:LoadBatchForAllDatasets()
  local batch, batch_paths = self.data:getBatch()

  local function slice(range)
    return batch[{ {}, range, {}, {} }]
  end

  local part_A = slice(self.idx_A)
  local part_B = slice(self.idx_B)
  return part_A, part_B, batch_paths, batch_paths
end

-- Number of samples available; returns whatever self.data:size() returns.
-- Because A and B are aligned, this is the size of each dataset, so the
-- `dataset` argument is accepted but not consulted (presumably kept to match
-- the BaseDataLoader interface -- confirm).
function AlignedDataLoader:size(dataset)
  local loaded = self.data
  return loaded:size()
end