Commit fc961107 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Allow setLocations to include additional sampleIdx column in locations

parent d796a754
...@@ -86,7 +86,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -86,7 +86,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
} }
if (locations->size[1] == Dimension + 1) { if (locations->size[1] == Dimension + 1) {
// add new samples to batch as necessary // add new samples to batch as necessary
auto SGs = *_m.inputSGs; auto &SGs = *_m.inputSGs;
for (uInt idx = 0; idx < locations->size[0]; ++idx) { for (uInt idx = 0; idx < locations->size[0]; ++idx) {
for (int d = 0; d < Dimension; ++d) for (int d = 0; d < Dimension; ++d)
p[d] = *l++; p[d] = *l++;
......
...@@ -33,7 +33,8 @@ class InputBatch(SparseConvNetTensor): ...@@ -33,7 +33,8 @@ class InputBatch(SparseConvNetTensor):
self.metadata.ffi, self.features, location, vector, overwrite) self.metadata.ffi, self.features, location, vector, overwrite)
def setLocations(self, locations, vectors, overwrite=False): def setLocations(self, locations, vectors, overwrite=False):
assert locations.min() >= 0 and (self.spatial_size.expand_as(locations) - locations).min() > 0 l =locations.narrow(1,0,self.dimension)
assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0
dim_fn(self.dimension, 'setInputSpatialLocations')( dim_fn(self.dimension, 'setInputSpatialLocations')(
self.metadata.ffi, self.features, locations, vectors, overwrite) self.metadata.ffi, self.features, locations, vectors, overwrite)
......
...@@ -47,7 +47,8 @@ return function(sparseconvnet) ...@@ -47,7 +47,8 @@ return function(sparseconvnet)
locations = torch.LongStorage(locations) locations = torch.LongStorage(locations)
end end
assert(locations:min()>=0 and (self.spatialSize:view(1, self.dimension):expandAs(locations)-locations):min()>0) local l = locations:narrow(2,1,self.dimension)
assert(l:min()>=0 and (self.spatialSize:view(1, self.dimension):expandAs(l)-l):min()>0)
C.dimensionFn(self.dimension,'setInputSpatialLocations')(self.metadata.ffi, C.dimensionFn(self.dimension,'setInputSpatialLocations')(self.metadata.ffi,
self.features:cdata(), locations:cdata(), vectors:cdata(), overwrite) self.features:cdata(), locations:cdata(), vectors:cdata(), overwrite)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment