"...llm/count/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "d38325c21f932395db5428fe463f8856a4e3afc1"
InputBatch.lua 2.8 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

return function(sparseconvnet)
  local C = sparseconvnet.C
  local InputBatch, parent = torch.class('sparseconvnet.InputBatch', sparseconvnet)

  --- Construct an empty input batch.
  -- Creates an empty float feature tensor and a Metadata object, records the
  -- spatial size and registers it with the dimension-specific C routine
  -- 'setInputSpatialSize'.
  -- @param dimension   number of spatial dimensions
  -- @param spatialSize a single number (applied to every dimension), or an
  --                    object exposing :cdata() that is stored as-is
  --                    (presumably a torch.LongTensor -- confirm with callers)
  function InputBatch:__init(dimension, spatialSize)
    self.dimension = dimension
    self.features = torch.Tensor():type('torch.FloatTensor')
    self.metadata = sparseconvnet.Metadata(dimension)

    -- A scalar size is expanded into a LongTensor with `dimension` entries.
    if type(spatialSize) == 'number' then
      self.spatialSize = torch.LongTensor(dimension):fill(spatialSize)
    else
      self.spatialSize = spatialSize
    end

    local setInputSpatialSize = C.dimensionFn(self.dimension, 'setInputSpatialSize')
    setInputSpatialSize(self.metadata.ffi, self.spatialSize:cdata())
  end
  --- Call the dimension-specific 'batchAddSample' C routine on this batch's
  -- metadata (presumably this opens a new sample in the batch -- confirm
  -- against the C implementation).
  function InputBatch:addSample()
    local batchAddSample = C.dimensionFn(self.dimension, 'batchAddSample')
    batchAddSample(self.metadata.ffi)
  end
  --- Call the dimension-specific 'addSampleFromThresholdedTensor' C routine.
  -- The routine receives this batch's metadata, feature tensor and spatial
  -- size together with the caller-supplied arguments.
  -- @param tensor    input tensor; must expose :cdata()
  -- @param offset    must expose :cdata() (presumably the position of `tensor`
  --                  inside the spatial grid -- confirm)
  -- @param threshold forwarded unchanged to the C routine
  function InputBatch:addSampleFromTensor(tensor,offset,threshold)
    local addFromTensor = C.dimensionFn(
      self.dimension, 'addSampleFromThresholdedTensor')
    addFromTensor(
      self.metadata.ffi,
      self.features:cdata(),
      tensor:cdata(),
      offset:cdata(),
      self.spatialSize:cdata(),
      threshold)
  end
  --- Set the feature vector at a single spatial location.
  -- @param location  self.dimension coordinates, given either as a
  --                  torch.LongTensor or as a plain Lua table of numbers.
  --                  Each coordinate c must satisfy 0 <= c < spatialSize.
  -- @param vector    feature values; must expose :cdata()
  -- @param overwrite forwarded unchanged to the C routine
  function InputBatch:setLocation(location, vector, overwrite)
    if type(location) == 'table' then
      assert(#location == self.dimension,
        'location must have exactly self.dimension coordinates')
      -- Build a LongTensor, not a LongStorage: the bounds check below relies
      -- on :min() and on tensor subtraction, which storages do not provide.
      location = torch.LongTensor(location)
    end
    assert(
      location:min() >= 0 and (self.spatialSize - location):min() > 0,
      'location is outside the input spatial size')
    local setInputSpatialLocation = C.dimensionFn(
      self.dimension, 'setInputSpatialLocation')
    setInputSpatialLocation(
      self.metadata.ffi,
      self.features:cdata(),
      location:cdata(),
      vector:cdata(),
      overwrite)
  end
Ed Ng's avatar
Ed Ng committed
43
44
  --- Set the feature vectors at several spatial locations in one call.
  -- @param locations 2-D set of coordinates, one row per location, given
  --                  either as a torch.LongTensor or as a 2-D Lua table.
  --                  Only the first self.dimension columns are bounds-checked;
  --                  the full object is passed to the C routine.
  -- @param vectors   feature values; must expose :cdata()
  -- @param overwrite forwarded unchanged to the C routine
  function InputBatch:setLocations(locations, vectors, overwrite)
    if type(locations) == 'table' then
      -- Build a LongTensor, not a LongStorage: :narrow() along dimension 2
      -- below is a tensor operation and a nested table needs a 2-D result.
      locations = torch.LongTensor(locations)
    end

    local coords = locations:narrow(2, 1, self.dimension)
    -- Broadcast the per-dimension sizes across all rows for the comparison.
    local limits = self.spatialSize:view(1, self.dimension):expandAs(coords)
    assert(
      coords:min() >= 0 and (limits - coords):min() > 0,
      'locations are outside the input spatial size')

    local setInputSpatialLocations = C.dimensionFn(
      self.dimension, 'setInputSpatialLocations')
    setInputSpatialLocations(
      self.metadata.ffi,
      self.features:cdata(),
      locations:cdata(),
      vectors:cdata(),
      overwrite)
  end
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
  --- Run one of the rule-book generation routines on this batch's metadata.
  -- @param stride when equal to 2, 'generateRuleBooks2s2' is called; for any
  --               other value (including nil) 'generateRuleBooks3s2' is called
  function InputBatch:precomputeMetadata(stride)
    local routine = 'generateRuleBooks3s2'
    if stride == 2 then
      routine = 'generateRuleBooks2s2'
    end
    C.dimensionFn(self.dimension, routine)(self.metadata.ffi)
  end
  --- Query or change the type of the feature tensor.
  -- @param t optional type name; when given, self.features is converted to
  --          that type and nothing is returned
  -- @return the current type of self.features when `t` is nil or false
  function InputBatch:type(t)
    if not t then
      return self.features:type()
    end
    self.features = self.features:type(t)
  end
end