InputBatch.lua 2.07 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

return function(sparseconvnet)
  local C = sparseconvnet.C
  local InputBatch, parent = torch.class('sparseconvnet.InputBatch', sparseconvnet)

  function InputBatch:__init(dimension, spatialSize)
    self.dimension = dimension
    self.features = torch.Tensor():type('torch.FloatTensor')
    self.metadata = sparseconvnet.Metadata(dimension)
    self.spatialSize = type(spatialSize)=='number' and torch.LongTensor(
      dimension):fill(spatialSize) or spatialSize
    C.dimensionFn(self.dimension,'setInputSpatialSize')(self.metadata.ffi,
      self.spatialSize:cdata())
  end
  function InputBatch:addSample()
    C.dimensionFn(self.dimension, 'batchAddSample')(self.metadata.ffi)
  end
  function InputBatch:addSampleFromTensor(tensor,offset,threshold)
    C.dimensionFn(
      self.dimension,'addSampleFromThresholdedTensor')(
      self.metadata.ffi, self.features:cdata(), tensor:cdata(), offset:cdata(),
      self.spatialSize:cdata(), threshold)
  end
  function InputBatch:setLocation(location, vector, overwrite)
    --[[location is a self.dimensional length set of coordinates:
    torch.LongStorage or a table]]
    if type(location)=='table' then
      local l=torch.LongStorage(self.dimension)
      for i,x in ipairs(location) do
        l[i]=x
      end
      location = l
    end
    assert(location:min()>=0 and (self.spatialSize-location):min()>0)
    C.dimensionFn(self.dimension,'setInputSpatialLocation')(self.metadata.ffi,
      self.features:cdata(), location:cdata(), vector:cdata(), overwrite)
  end
  function InputBatch:precomputeMetadata(stride)
    if stride==2 then
      C.dimensionFn(self.dimension,'generateRuleBooks2s2')(self.metadata.ffi)
    else
      C.dimensionFn(self.dimension,'generateRuleBooks3s2')(self.metadata.ffi)
    end
  end
  function InputBatch:type(t)
    if t then
      self.features = self.features:type(t)
    else
      return self.features:type()
    end
  end
end