data.lua 4.13 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

tnt=require 'torchnet'
scn=require 'sparseconvnet'
require 'paths'

if not paths.dirp('t7/') then
  print('Downloading and preprocessing data ...')
  os.execute('bash process.sh')
  dofile('process.lua')
end

function train(spatialSize,Scale,precomputeStride)
  local d=torch.load('t7/train.t7')
  print('Replicating training set 10 times (1 epoch = 10 iterations through the training set = 10x6588 training samples)')
  for i=1,9 do
    for j=1,6588 do
      d[#d+1]=d[j]
    end
  end
  d=tnt.ListDataset(d,function(x) return x end):shuffle()
  function d:manualSeed(seed) torch.manualSeed(seed) end
  d=tnt.BatchDataset(d,108,function(idx, size) return idx end,function (tbl)
      input=scn.InputBatch(2,spatialSize)
      local center=spatialSize:float():view(1,2)/2
      local p=torch.LongTensor(2)
      local v=torch.FloatTensor({1,0,0})
      for _,char in ipairs(tbl.input) do
        input:addSample()
        local m=torch.eye(2):float()
        local r=torch.random(3)
        local alpha=torch.uniform(-0.2,0.2)
        if alpha==1 then
          m[1][2]=alpha
        elseif alpha==2 then
          m[2][1]=alpha
        else
          m=torch.mm(m,torch.FloatTensor({
                {math.cos(alpha),math.sin(alpha)},
                {-math.sin(alpha),math.cos(alpha)}}))
        end
        c=(center+torch.FloatTensor(1,2):uniform(-8,8))
        for _,stroke in ipairs(char) do
          stroke=stroke:float()/255-0.5
          stroke=c:expandAs(stroke)+stroke*m*(Scale-0.01)
          --------------------------------------------------------------
          -- Draw pen stroke:
          scn.C.dimensionFn(2,'drawCurve')(
            input.metadata.ffi,input.features:cdata(),stroke:cdata())
          --------------------------------------------------------------
          -- Above is equivalent to :
          -- local x1,x2,y1,y2,l=0,stroke[1][1],0,stroke[1][2],0
          -- for i=2,stroke:size(1) do
          -- x1=x2
          -- y1=y2
          -- x2=stroke[i][1]
          -- y2=stroke[i][2]
          -- l=1e-10+((x2-x1)^2+(y2-y1)^2)^0.5
          -- v[2]=(x2-x1)/l
          -- v[3]=(y2-y1)/l
          -- l=math.max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
          -- for j=0,1,1/l do
          -- p[1]=math.floor(x1*j+x2*(1-j))
          -- p[2]=math.floor(y1*j+y2*(1-j))
          -- input:setLocation(p,v,false)
          -- end
          -- end
          --------------------------------------------------------------
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  d=tnt.ParallelDatasetIterator({
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return d end,
      ordered = true})
  return function(epoch)
    d:exec('manualSeed', epoch)
    d:exec('resample')
    return d()
  end
end
function val(spatialSize,Scale,precomputeStride)
  local d=torch.load('t7/test.t7')
  d=tnt.ListDataset(d,function(x) return x end)
  d=tnt.BatchDataset(d,183,function(idx, size) return idx end,function (tbl)
      input=scn.InputBatch(2,spatialSize)
      local center=spatialSize:float():view(1,2)/2
      local p=torch.LongTensor(2)
      local v=torch.FloatTensor({1,0,0})
      for _,char in ipairs(tbl.input) do
        input:addSample()
        for _,stroke in ipairs(char) do
          stroke=stroke:float()/255-0.5
          stroke=center:expandAs(stroke)+stroke*(Scale-0.01)
          scn.C.dimensionFn(2,'drawCurve')(input.metadata.ffi,input.features:cdata(),stroke:cdata())
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  d=tnt.ParallelDatasetIterator({
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return d end,
      ordered = true})
  return function()
    return d()
  end
end

return function(...)
  return {train=train(...), val=val(...)}
end