Commit 6de372c3 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

add batch information to getSpatialLocations

parent 96addd52
......@@ -63,9 +63,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
assert((locations->size[1] == Dimension or
locations->size[1] == 1 + Dimension) and
"locations.size(0) must be either Dimension or Dimension+1");
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
Point<Dimension> p;
auto &nActive = *_m.inputNActive;
auto nPlanes = vecs->size[1];
......@@ -73,7 +71,8 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
auto v = THFloatTensor_data(vecs);
if (locations->size[1] == Dimension) {
assert(_m.inputSG); // add points to current sample
// add points to current sample
assert(_m.inputSG);
auto &mp = _m.inputSG->mp;
for (uInt idx = 0; idx < locations->size[0]; ++idx) {
for (int d = 0; d < Dimension; ++d)
......@@ -101,9 +100,9 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
}
}
extern "C" void scn_D_(getSpatialLocations)(void **m,
THLongTensor *spatialSize,
THLongTensor *locations) {
extern "C" void scn_D_(getSpatialLocations)(void **m, THLongTensor *spatialSize,
THLongTensor *locations,
THLongTensor *batchIdxs) {
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
uInt nActive = _m.getNActive(spatialSize);
auto &SGs = _m.getSparseGrid(spatialSize);
......@@ -125,9 +124,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
}
}
extern "C" void
scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_,
THLongTensor *pad_,
THLongTensor *nz_, long batchSize) {
scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_,
THLongTensor *pad_, THLongTensor *nz_,
long batchSize) {
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
_m.setInputSpatialSize(spatialSize_);
_m.inputSGs->resize(batchSize);
......
This diff is collapsed.
This diff is collapsed.
......@@ -33,7 +33,25 @@ class InputBatch(SparseConvNetTensor):
self.metadata.ffi, self.features, location, vector, overwrite)
def setLocations(self, locations, vectors, overwrite=False):
l =locations.narrow(1,0,self.dimension)
"""
To set n locations in d dimensions, locations can be
- A size (n,d) LongTensor, giving d-dimensional coordinates -- points
are added to the current sample, or
- A size (n,d+1) LongTensor; the extra column specifies the sample
number (within the minibatch of samples).
Example with d=3 and n=2:
Set
locations = LongTensor([[1,2,3],
[4,5,6]])
to add points to the current sample at (1,2,3) and (4,5,6).
Set
locations = LongTensor([[1,2,3,7],
[4,5,6,9]])
to add point (1,2,3) to sample 7, and (4,5,6) to sample 9 (0-indexed).
"""
l = locations.narrow(1,0,self.dimension)
assert l.min() >= 0 and (self.spatial_size.expand_as(l) - l).min() > 0
dim_fn(self.dimension, 'setInputSpatialLocations')(
self.metadata.ffi, self.features, locations, vectors, overwrite)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment