--[[ Copyright (c) 2015-present, Facebook, Inc. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. An additional grant of patent rights can be found in the PATENTS file in the same directory. ]]-- require 'torch' torch.setdefaulttensortype('torch.FloatTensor') local ffi = require 'ffi' local class = require('pl.class') local dir = require 'pl.dir' local tablex = require 'pl.tablex' local argcheck = require 'argcheck' require 'sys' require 'xlua' require 'image' local dataset = torch.class('dataLoader') local initcheck = argcheck{ pack=true, help=[[ A dataset class for images in a flat folder structure (folder-name is class-name). Optimized for extremely large datasets (upwards of 14 million images). Tested only on Linux (as it uses command-line linux utilities to scale up) ]], {check=function(paths) local out = true; for k,v in ipairs(paths) do if type(v) ~= 'string' then print('paths can only be of string input'); out = false end end return out end, name="paths", type="table", help="Multiple paths of directories with images"}, {name="sampleSize", type="table", help="a consistent sample size to resize the images"}, {name="split", type="number", help="Percentage of split to go to Training" }, {name="serial_batches", type="number", help="if randomly sample training images"}, {name="samplingMode", type="string", help="Sampling mode: random | balanced ", default = "balanced"}, {name="verbose", type="boolean", help="Verbose mode during initialization", default = false}, {name="loadSize", type="table", help="a size to load the images to, initially", opt = true}, {name="forceClasses", type="table", help="If you want this loader to map certain classes to certain indices, " .. "pass a classes table that has {classname : classindex} pairs." .. " For example: {3 : 'dog', 5 : 'cat'}" .. "This function is very useful when you want two loaders to have the same " .. "class indices (trainLoader/testLoader for example)", opt = true}, {name="sampleHookTrain", type="function", help="applied to sample during training(ex: for lighting jitter). " .. "It takes the image path as input", opt = true}, {name="sampleHookTest", type="function", help="applied to sample during testing", opt = true}, } function dataset:__init(...) -- argcheck local args = initcheck(...) print(args) for k,v in pairs(args) do self[k] = v end if not self.loadSize then self.loadSize = self.sampleSize; end if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end self.image_count = 1 -- print('image_count_init', self.image_count) -- find class names self.classes = {} local classPaths = {} if self.forceClasses then for k,v in pairs(self.forceClasses) do self.classes[k] = v classPaths[k] = {} end end local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end -- loop over each paths folder, get list of unique class names, -- also store the directory paths per class -- for each class, for k,path in ipairs(self.paths) do -- print('path', path) local dirs = {} -- hack dirs[1] = path -- local dirs = dir.getdirectories(path); for k,dirpath in ipairs(dirs) do local class = paths.basename(dirpath) local idx = tableFind(self.classes, class) -- print(class) -- print(idx) if not idx then table.insert(self.classes, class) idx = #self.classes classPaths[idx] = {} end if not tableFind(classPaths[idx], dirpath) then table.insert(classPaths[idx], dirpath); end end end self.classIndices = {} for k,v in ipairs(self.classes) do self.classIndices[v] = k end -- define command-line tools, try your best to maintain OSX compatibility local wc = 'wc' local cut = 'cut' local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing if ffi.os == 'OSX' then wc = 'gwc' cut = 'gcut' find = 'gfind' end ---------------------------------------------------------------------- -- Options for the GNU find command local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} local findOptions = ' -iname "*.' .. extensionList[1] .. '"' for i=2,#extensionList do findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' end -- find the image path names self.imagePath = torch.CharTensor() -- path to each image in dataset self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) self.classList = {} -- index of imageList to each image of a particular class self.classListSample = self.classList -- the main list used when sampling data print('running "find" on each class directory, and concatenate all' .. ' those filenames into a single file containing all image paths for a given class') -- so, generates one file per class local classFindFiles = {} for i=1,#self.classes do classFindFiles[i] = os.tmpname() end local combinedFindList = os.tmpname(); local tmpfile = os.tmpname() local tmphandle = assert(io.open(tmpfile, 'w')) -- iterate over classes for i, class in ipairs(self.classes) do -- iterate over classPaths for j,path in ipairs(classPaths[i]) do local command = find .. ' "' .. path .. '" ' .. findOptions .. ' >>"' .. classFindFiles[i] .. '" \n' tmphandle:write(command) end end io.close(tmphandle) os.execute('bash ' .. tmpfile) os.execute('rm -f ' .. tmpfile) print('now combine all the files to a single large file') local tmpfile = os.tmpname() local tmphandle = assert(io.open(tmpfile, 'w')) -- concat all finds to a single large file in the order of self.classes for i=1,#self.classes do local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' tmphandle:write(command) end io.close(tmphandle) os.execute('bash ' .. tmpfile) os.execute('rm -f ' .. tmpfile) --========================================================================== print('load the large concatenated list of sample paths to self.imagePath') local cmd = wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '" print('cmd..' .. cmd) local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '")) + 1 local length = tonumber(sys.fexecute(wc .. " -l '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '")) assert(length > 0, "Could not find any image file in the given input paths") assert(maxPathLength > 0, "paths of files are length 0?") self.imagePath:resize(length, maxPathLength):fill(0) local s_data = self.imagePath:data() local count = 0 for line in io.lines(combinedFindList) do ffi.copy(s_data, line) s_data = s_data + maxPathLength if self.verbose and count % 10000 == 0 then xlua.progress(count, length) end; count = count + 1 end self.numSamples = self.imagePath:size(1) if self.verbose then print(self.numSamples .. ' samples found.') end --========================================================================== print('Updating classList and imageClass appropriately') self.imageClass:resize(self.numSamples) local runningIndex = 0 for i=1,#self.classes do if self.verbose then xlua.progress(i, #(self.classes)) end local length = tonumber(sys.fexecute(wc .. " -l '" .. classFindFiles[i] .. "' |" .. cut .. " -f1 -d' '")) if length == 0 then error('Class has zero samples') else self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long() self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) end runningIndex = runningIndex + length end --========================================================================== -- clean up temporary files print('Cleaning up temporary files') local tmpfilelistall = '' for i=1,#(classFindFiles) do tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"' if i % 1000 == 0 then os.execute('rm -f ' .. tmpfilelistall) tmpfilelistall = '' end end os.execute('rm -f ' .. tmpfilelistall) os.execute('rm -f "' .. combinedFindList .. '"') --========================================================================== if self.split == 100 then self.testIndicesSize = 0 else print('Splitting training and test sets to a ratio of ' .. self.split .. '/' .. (100-self.split)) self.classListTrain = {} self.classListTest = {} self.classListSample = self.classListTrain local totalTestSamples = 0 -- split the classList into classListTrain and classListTest for i=1,#self.classes do local list = self.classList[i] local count = self.classList[i]:size(1) local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round local perm = torch.randperm(count) self.classListTrain[i] = torch.LongTensor(splitidx) for j=1,splitidx do self.classListTrain[i][j] = list[perm[j]] end if splitidx == count then -- all samples were allocated to train set self.classListTest[i] = torch.LongTensor() else self.classListTest[i] = torch.LongTensor(count-splitidx) totalTestSamples = totalTestSamples + self.classListTest[i]:size(1) local idx = 1 for j=splitidx+1,count do self.classListTest[i][idx] = list[perm[j]] idx = idx + 1 end end end -- Now combine classListTest into a single tensor self.testIndices = torch.LongTensor(totalTestSamples) self.testIndicesSize = totalTestSamples local tdata = self.testIndices:data() local tidx = 0 for i=1,#self.classes do local list = self.classListTest[i] if list:dim() ~= 0 then local ldata = list:data() for j=0,list:size(1)-1 do tdata[tidx] = ldata[j] tidx = tidx + 1 end end end end end -- size(), size(class) function dataset:size(class, list) list = list or self.classList if not class then return self.numSamples elseif type(class) == 'string' then return list[self.classIndices[class]]:size(1) elseif type(class) == 'number' then return list[class]:size(1) end end -- getByClass function dataset:getByClass(class) local index = 0 if self.serial_batches == 1 then index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1 self.image_count = self.image_count +1 else index = math.ceil(torch.uniform() * self.classListSample[class]:nElement()) end local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]])) return self:sampleHookTrain(imgpath), imgpath end -- converts a table of samples (and corresponding labels) to a clean tensor local function tableToOutput(self, dataTable, scalarTable) local data, scalarLabels, labels if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then assert(#scalarTable == 1) data = torch.Tensor(1, dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3)) data[1]:copy(dataTable[1]) scalarLabels = torch.LongTensor(#scalarTable):fill(-1111) else local quantity = #scalarTable data = torch.Tensor(quantity, self.sampleSize[1], self.sampleSize[2], self.sampleSize[3]) scalarLabels = torch.LongTensor(quantity):fill(-1111) for i=1,#dataTable do data[i]:copy(dataTable[i]) scalarLabels[i] = scalarTable[i] end end return data, scalarLabels end -- sampler, samples from the training set. function dataset:sample(quantity) assert(quantity) local dataTable = {} local scalarTable = {} local samplePaths = {} for i=1,quantity do local class = torch.random(1, #self.classes) local out, imgpath = self:getByClass(class) table.insert(dataTable, out) table.insert(scalarTable, class) samplePaths[i] = imgpath end local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) return data, scalarLabels, samplePaths-- filePaths end function dataset:get(i1, i2) local indices = torch.range(i1, i2); local quantity = i2 - i1 + 1; assert(quantity > 0) -- now that indices has been initialized, get the samples local dataTable = {} local scalarTable = {} for i=1,quantity do -- load the sample local imgpath = ffi.string(torch.data(self.imagePath[indices[i]])) local out = self:sampleHookTest(imgpath) table.insert(dataTable, out) table.insert(scalarTable, self.imageClass[indices[i]]) end local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) return data, scalarLabels end return dataset