data.lua 2.63 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

-- ModelNet-40 data - https://github.com/charlesq34/3dcnn.torch
-- input is a list of active coordinates in a box [0,29]^3

tnt=require 'torchnet'
scn=require 'sparseconvnet'
require 'paths'

-- The loaders below read samples from the 't7' directory. When that directory
-- does not exist yet, run process.lua first (presumably it creates 't7' --
-- see that script).
local haveData = paths.dirp('t7')
if not haveData then
  print('Downloading and preprocessing data ...')
  dofile('process.lua')
end

--- Build the training-set loader.
-- Samples are read from 't7/train/<index>.t7' for index 1..590580, shuffled,
-- grouped into batches of 100 and converted to scn.InputBatch objects by 10
-- worker threads.
-- @param spatialSize forwarded to scn.InputBatch and used to compute the
--   centring offset (spatialSize/2-15). Presumably a 3-element
--   torch.LongTensor -- TODO confirm against callers.
-- @param precomputeStride forwarded to input:precomputeMetadata
-- @return function(epoch): seeds the workers with `epoch` (torch.manualSeed),
--   sends them 'resample', then returns the result of calling the parallel
--   iterator; each batch is a table {input=<scn.InputBatch>, target=<LongTensor>}
function train(spatialSize,precomputeStride)
  -- One list entry per sample file; the entry value is the file index.
  local indices={}
  for x=1,590580 do
    indices[x]=x
  end
  local shuffled=tnt.ListDataset(indices,function(x) return torch.load('t7/train/'..x..'.t7') end):shuffle()
  -- Reachable through exec('manualSeed', ...) so each epoch can reseed torch's RNG.
  function shuffled:manualSeed(seed) torch.manualSeed(seed) end
  local batched=tnt.BatchDataset(shuffled,100,function(idx, size) return idx end,function (tbl)
      -- `local` is essential: a bare assignment would create a global inside
      -- every worker thread.
      local input=scn.InputBatch(3,spatialSize)
      local offset = spatialSize/2-15
      -- Feature vector written at every active location (a single value 1).
      local v=torch.FloatTensor({1})
      for _,obj in ipairs(tbl.input) do
        input:addSample()
        -- Shift the whole sample by the centring offset plus a random
        -- per-axis jitter drawn from [-2,2].
        obj=obj:type('torch.LongTensor'):add((offset+torch.LongTensor(3):random(-2,2)):view(1,3):expandAs(obj))
        for i=1,obj:size(1) do
          -- NOTE(review): meaning of the trailing 0 is defined by
          -- scn.InputBatch:setLocation -- confirm there.
          input:setLocation(obj[i],v,0)
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  local iterator=tnt.ParallelDatasetIterator({
      -- Each worker needs its own copies of the libraries; `scn` is set as a
      -- global there because the batch transform above refers to it by name.
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return batched end,
      ordered = true})
  return function(epoch)
    iterator:exec('manualSeed', epoch)
    iterator:exec('resample')
    return iterator()
  end
end
--- Build the validation-set loader.
-- Samples are read in order from 't7/test/<index>.t7' for index 1..148080,
-- grouped into batches of 100 and converted to scn.InputBatch objects by 10
-- worker threads. No shuffling and no random jitter is applied here.
-- @param spatialSize forwarded to scn.InputBatch and used to compute the
--   centring offset (spatialSize/2-15). Must support :view(1,3), so
--   presumably a 3-element torch.LongTensor -- TODO confirm against callers.
-- @param precomputeStride forwarded to input:precomputeMetadata
-- @return function(): returns the result of calling the parallel iterator;
--   each batch is a table {input=<scn.InputBatch>, target=<LongTensor>}
function val(spatialSize,precomputeStride)
  -- One list entry per sample file; the entry value is the file index.
  local indices={}
  for x=1,148080 do
    indices[x]=x
  end
  local listed=tnt.ListDataset(indices,function(x) return torch.load('t7/test/'..x..'.t7') end)
  local batched=tnt.BatchDataset(listed,100,function(idx, size) return idx end,function (tbl)
      -- `local` is essential: a bare assignment would create a global inside
      -- every worker thread.
      local input=scn.InputBatch(3,spatialSize)
      -- Feature vector written at every active location (a single value 1).
      local v=torch.FloatTensor({1})
      -- Already shaped 1x3 here, so it can be expanded per sample directly.
      local offset = (spatialSize/2-15):view(1,3)
      for _,obj in ipairs(tbl.input) do
        input:addSample()
        obj=obj:type('torch.LongTensor'):add(offset:expandAs(obj))
        for i=1,obj:size(1) do
          -- NOTE(review): meaning of the trailing 0 is defined by
          -- scn.InputBatch:setLocation -- confirm there.
          input:setLocation(obj[i],v,0)
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  local iterator=tnt.ParallelDatasetIterator({
      -- Each worker needs its own copies of the libraries; `scn` is set as a
      -- global there because the batch transform above refers to it by name.
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return batched end,
      ordered = true})
  return function()
    return iterator()
  end
end

-- Module value: a factory that forwards its arguments unchanged to both
-- train() and val() and returns the two resulting loader functions in a
-- table with fields `train` and `val`.
return function(...)
  local loaders = {}
  loaders.train = train(...)
  loaders.val = val(...)
  return loaders
end