Commit 6de372c3 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

add batch information to getSpatialLocations

parent 96addd52
...@@ -63,9 +63,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -63,9 +63,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
assert((locations->size[1] == Dimension or assert((locations->size[1] == Dimension or
locations->size[1] == 1 + Dimension) and locations->size[1] == 1 + Dimension) and
"locations.size(0) must be either Dimension or Dimension+1"); "locations.size(0) must be either Dimension or Dimension+1");
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m) SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
Point<Dimension> p; Point<Dimension> p;
auto &nActive = *_m.inputNActive; auto &nActive = *_m.inputNActive;
auto nPlanes = vecs->size[1]; auto nPlanes = vecs->size[1];
...@@ -73,7 +71,8 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -73,7 +71,8 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
auto v = THFloatTensor_data(vecs); auto v = THFloatTensor_data(vecs);
if (locations->size[1] == Dimension) { if (locations->size[1] == Dimension) {
assert(_m.inputSG); // add points to current sample // add points to current sample
assert(_m.inputSG);
auto &mp = _m.inputSG->mp; auto &mp = _m.inputSG->mp;
for (uInt idx = 0; idx < locations->size[0]; ++idx) { for (uInt idx = 0; idx < locations->size[0]; ++idx) {
for (int d = 0; d < Dimension; ++d) for (int d = 0; d < Dimension; ++d)
...@@ -101,9 +100,9 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -101,9 +100,9 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
} }
} }
extern "C" void scn_D_(getSpatialLocations)(void **m, extern "C" void scn_D_(getSpatialLocations)(void **m, THLongTensor *spatialSize,
THLongTensor *spatialSize, THLongTensor *locations,
THLongTensor *locations) { THLongTensor *batchIdxs) {
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m) SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
uInt nActive = _m.getNActive(spatialSize); uInt nActive = _m.getNActive(spatialSize);
auto &SGs = _m.getSparseGrid(spatialSize); auto &SGs = _m.getSparseGrid(spatialSize);
...@@ -125,9 +124,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m, ...@@ -125,9 +124,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
} }
} }
extern "C" void extern "C" void
scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_, scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_,
THLongTensor *pad_, THLongTensor *pad_, THLongTensor *nz_,
THLongTensor *nz_, long batchSize) { long batchSize) {
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m) SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
_m.setInputSpatialSize(spatialSize_); _m.setInputSpatialSize(spatialSize_);
_m.inputSGs->resize(batchSize); _m.inputSGs->resize(batchSize);
......
This diff is collapsed.
This diff is collapsed.
...@@ -33,7 +33,25 @@ class InputBatch(SparseConvNetTensor): ...@@ -33,7 +33,25 @@ class InputBatch(SparseConvNetTensor):
self.metadata.ffi, self.features, location, vector, overwrite) self.metadata.ffi, self.features, location, vector, overwrite)
def setLocations(self, locations, vectors, overwrite=False): def setLocations(self, locations, vectors, overwrite=False):
l =locations.narrow(1,0,self.dimension) """
To set n locations in d dimensions, locations can be
- A size (n,d) LongTensor, giving d-dimensional coordinates -- points
are added to the current sample, or
- A size (n,d+1) LongTensor; the extra column specifies the sample
number (within the minibatch of samples).
Example with d=3 and n=2:
Set
locations = LongTensor([[1,2,3],
[4,5,6]])
to add points to the current sample at (1,2,3) and (4,5,6).
Set
locations = LongTensor([[1,2,3,7],
[4,5,6,9]])
to add point (1,2,3) to sample 7, and (4,5,6) to sample 9 (0-indexed).
"""
l = locations.narrow(1,0,self.dimension)
assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0 assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0
dim_fn(self.dimension, 'setInputSpatialLocations')( dim_fn(self.dimension, 'setInputSpatialLocations')(
self.metadata.ffi, self.features, locations, vectors, overwrite) self.metadata.ffi, self.features, locations, vectors, overwrite)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment