data.lua 4.42 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

-- CASIA-OLHWDB1.1 dataset, 3755 character classes
-- 898573 training samples (240 writers), 224559 test samples (60 writers)
-- Data from http://www.nlpr.ia.ac.cn/databases/handwriting/Home.html

-- Consider moving the data folder to a ramdisk, i.e.
-- mkdir /ramdisk
-- mount -t tmpfs -o size=5g tmpfs /ramdisk
-- rsync -r t7/ /ramdisk/t7/
-- change t7/ to /ramdisk/t7/ below (twice)

tnt=require 'torchnet'
scn=require 'sparseconvnet'
require 'paths'

-- First-run setup: download, extract and preprocess the dataset.
-- 't7/' existing is the marker that preprocessing has already been done.
if not paths.dirp('t7/') then
  -- Only download when the training archive is not already on disk.
  if not paths.filep('OLHWDB1.1trn_pot.zip') then
    print('Downloading data ...')
    os.execute('wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1trn_pot.zip')
    os.execute('wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1tst_pot.zip')
  end
  -- Extract unconditionally: previously extraction was skipped whenever the
  -- zip already existed, leaving t7/ and POT/ missing for preprocessing.
  -- '-o' overwrites without prompting if POT/ was partially extracted before.
  os.execute('mkdir -p t7/train/ t7/test/ POT/')
  os.execute('unzip -o OLHWDB1.1trn_pot.zip -d POT/')
  os.execute('unzip -o OLHWDB1.1tst_pot.zip -d POT/')
  print('Preprocessing data ...')
  os.execute('python3 readPotFiles.py')
  dofile('process.lua')
end

-- Build the training-set iterator factory.
-- spatialSize: LongTensor spatial extent of the InputBatch grid.
-- Scale: character rescaling factor applied to normalized stroke coords.
-- precomputeStride: stride for precomputing sparse metadata.
-- Returns function(epoch) -> one epoch's shuffled, batched iterator.
function train(spatialSize,Scale,precomputeStride)
  local d={}
  -- 898573 training samples, stored one per file as t7/train/<i>.t7
  for x=1,898573
  do
    d[x]=x
  end
  d=tnt.ListDataset(d,function(x) return torch.load('t7/train/'..x..'.t7') end):shuffle()
  -- Expose a seeding hook so each epoch reshuffles deterministically
  -- (invoked below via ParallelDatasetIterator:exec).
  function d:manualSeed(seed) torch.manualSeed(seed) end
  d=tnt.BatchDataset(d,100,function(idx, size) return idx end,function (tbl)
      -- 'local' is essential: the original leaked 'input' as a global,
      -- an accidental shared binding within each worker thread's state.
      local input=scn.InputBatch(2,spatialSize)
      local center=spatialSize:float():view(1,2)/2
      local p=torch.LongTensor(2)
      local v=torch.FloatTensor({1,0,0})
      for _,char in ipairs(tbl.input) do
        input:addSample()
        for _,stroke in ipairs(char) do
          -- Normalize byte coords to [-0.5,0.5], then scale and recenter.
          stroke=stroke:float()/255-0.5
          stroke=center:expandAs(stroke)+stroke*(Scale-0.01)
          --------------------------------------------------------------
          -- Draw pen stroke (C implementation for speed):
          scn.C.dimensionFn(2,'drawCurve')(
            input.metadata.ffi,input.features:cdata(),stroke:cdata())
          --------------------------------------------------------------
          -- Above is equivalent to :
          -- local x1,x2,y1,y2,l=0,stroke[1][1],0,stroke[1][2],0
          -- for i=2,stroke:size(1) do
          -- x1=x2
          -- y1=y2
          -- x2=stroke[i][1]
          -- y2=stroke[i][2]
          -- l=1e-10+((x2-x1)^2+(y2-y1)^2)^0.5
          -- v[2]=(x2-x1)/l
          -- v[3]=(y2-y1)/l
          -- l=math.max(x2-x1,y2-y1,x1-x2,y1-y2,0.9)
          -- for j=0,1,1/l do
          -- p[1]=math.floor(x1*j+x2*(1-j))
          -- p[2]=math.floor(y1*j+y2*(1-j))
          -- input:setLocation(p,v,false)
          -- end
          -- end
          --------------------------------------------------------------
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  -- 10 worker threads; each re-requires the libraries in its own Lua state.
  d=tnt.ParallelDatasetIterator({
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return d end,
      ordered = true})
  return function(epoch)
    -- Seed with the epoch number so each epoch gets a distinct but
    -- reproducible shuffle, then resample the shuffle permutation.
    d:exec('manualSeed', epoch)
    d:exec('resample')
    return d()
  end
end
-- Build the validation/test-set iterator factory.
-- Same pipeline as train() but without shuffling or per-epoch reseeding.
-- spatialSize: LongTensor spatial extent of the InputBatch grid.
-- Scale: character rescaling factor applied to normalized stroke coords.
-- precomputeStride: stride for precomputing sparse metadata.
-- Returns function() -> iterator over the 224559 test samples.
function val(spatialSize,Scale,precomputeStride)
  local d={}
  for x=1,224559
  do
    d[x]=x
  end
  d=tnt.ListDataset(d,function(x) return torch.load('t7/test/'..x..'.t7') end)
  d=tnt.BatchDataset(d,100,function(idx, size) return idx end,function (tbl)
      -- 'local' is essential: the original leaked 'input' as a global,
      -- an accidental shared binding within each worker thread's state.
      local input=scn.InputBatch(2,spatialSize)
      local center=spatialSize:float():view(1,2)/2
      local p=torch.LongTensor(2)
      local v=torch.FloatTensor({1,0,0})
      for _,char in ipairs(tbl.input) do
        input:addSample()
        for _,stroke in ipairs(char) do
          -- Normalize byte coords to [-0.5,0.5], then scale and recenter.
          stroke=stroke:float()/255-0.5
          stroke=center:expandAs(stroke)+stroke*(Scale-0.01)
          -- Draw pen stroke (C implementation; see train() for the
          -- equivalent pure-Lua version).
          scn.C.dimensionFn(2,'drawCurve')(
            input.metadata.ffi,input.features:cdata(),stroke:cdata())
        end
      end
      input:precomputeMetadata(precomputeStride)
      return {input=input,target=torch.LongTensor(tbl.target)}
    end
  )
  -- 10 worker threads; each re-requires the libraries in its own Lua state.
  d=tnt.ParallelDatasetIterator({
      init = function() require 'torchnet'; scn=require 'sparseconvnet' end,
      nthread = 10,
      closure = function() return d end,
      ordered = true})
  return function()
    return d()
  end
end

-- Module entry point: forwards its arguments (spatialSize, Scale,
-- precomputeStride) to both dataset builders and returns the pair.
return function(...)
  local train_iterator = train(...)
  local val_iterator = val(...)
  return {train = train_iterator, val = val_iterator}
end