Commit fc961107 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Allow setLocations to include additional sampleIdx column in locations

parent d796a754
...@@ -86,7 +86,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -86,7 +86,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
} }
if (locations->size[1] == Dimension + 1) { if (locations->size[1] == Dimension + 1) {
// add new samples to batch as necessary // add new samples to batch as necessary
auto SGs = *_m.inputSGs; auto &SGs = *_m.inputSGs;
for (uInt idx = 0; idx < locations->size[0]; ++idx) { for (uInt idx = 0; idx < locations->size[0]; ++idx) {
for (int d = 0; d < Dimension; ++d) for (int d = 0; d < Dimension; ++d)
p[d] = *l++; p[d] = *l++;
......
...@@ -33,7 +33,8 @@ class InputBatch(SparseConvNetTensor): ...@@ -33,7 +33,8 @@ class InputBatch(SparseConvNetTensor):
self.metadata.ffi, self.features, location, vector, overwrite) self.metadata.ffi, self.features, location, vector, overwrite)
def setLocations(self, locations, vectors, overwrite=False): def setLocations(self, locations, vectors, overwrite=False):
assert locations.min() >= 0 and (self.spatial_size.expand_as(locations) - locations).min() > 0 l =locations.narrow(1,0,self.dimension)
assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0
dim_fn(self.dimension, 'setInputSpatialLocations')( dim_fn(self.dimension, 'setInputSpatialLocations')(
self.metadata.ffi, self.features, locations, vectors, overwrite) self.metadata.ffi, self.features, locations, vectors, overwrite)
......
...@@ -15,7 +15,7 @@ return function(sparseconvnet) ...@@ -15,7 +15,7 @@ return function(sparseconvnet)
self.spatialSize = type(spatialSize)=='number' and torch.LongTensor( self.spatialSize = type(spatialSize)=='number' and torch.LongTensor(
dimension):fill(spatialSize) or spatialSize dimension):fill(spatialSize) or spatialSize
C.dimensionFn(self.dimension,'setInputSpatialSize')(self.metadata.ffi, C.dimensionFn(self.dimension,'setInputSpatialSize')(self.metadata.ffi,
self.spatialSize:cdata()) self.spatialSize:cdata())
end end
function InputBatch:addSample() function InputBatch:addSample()
C.dimensionFn(self.dimension, 'batchAddSample')(self.metadata.ffi) C.dimensionFn(self.dimension, 'batchAddSample')(self.metadata.ffi)
...@@ -28,7 +28,7 @@ return function(sparseconvnet) ...@@ -28,7 +28,7 @@ return function(sparseconvnet)
end end
function InputBatch:setLocation(location, vector, overwrite) function InputBatch:setLocation(location, vector, overwrite)
--[[location is a self.dimensional length set of coordinates: --[[location is a self.dimensional length set of coordinates:
torch.LongStorage or a table]] torch.LongStorage or a table]]
if type(location)=='table' then if type(location)=='table' then
local l=torch.LongStorage(self.dimension) local l=torch.LongStorage(self.dimension)
for i,x in ipairs(location) do for i,x in ipairs(location) do
...@@ -38,19 +38,20 @@ return function(sparseconvnet) ...@@ -38,19 +38,20 @@ return function(sparseconvnet)
end end
assert(location:min()>=0 and (self.spatialSize-location):min()>0) assert(location:min()>=0 and (self.spatialSize-location):min()>0)
C.dimensionFn(self.dimension,'setInputSpatialLocation')(self.metadata.ffi, C.dimensionFn(self.dimension,'setInputSpatialLocation')(self.metadata.ffi,
self.features:cdata(), location:cdata(), vector:cdata(), overwrite) self.features:cdata(), location:cdata(), vector:cdata(), overwrite)
end end
function InputBatch:setLocations(locations, vectors, overwrite) function InputBatch:setLocations(locations, vectors, overwrite)
--[[locations is a n_locations x self.dimensional length set of coordinates: --[[locations is a n_locations x self.dimensional length set of coordinates:
torch.LongStorage or a 2-D table]] torch.LongStorage or a 2-D table]]
if type(locations)=='table' then if type(locations)=='table' then
locations = torch.LongStorage(locations) locations = torch.LongStorage(locations)
end end
assert(locations:min()>=0 and (self.spatialSize:view(1, self.dimension):expandAs(locations)-locations):min()>0) local l = locations:narrow(2,1,self.dimension)
assert(l:min()>=0 and (self.spatialSize:view(1, self.dimension):expandAs(l)-l):min()>0)
C.dimensionFn(self.dimension,'setInputSpatialLocations')(self.metadata.ffi, C.dimensionFn(self.dimension,'setInputSpatialLocations')(self.metadata.ffi,
self.features:cdata(), locations:cdata(), vectors:cdata(), overwrite) self.features:cdata(), locations:cdata(), vectors:cdata(), overwrite)
end end
function InputBatch:precomputeMetadata(stride) function InputBatch:precomputeMetadata(stride)
if stride==2 then if stride==2 then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment