Commit 94a39536 authored by Ed Ng's avatar Ed Ng
Browse files

Add Get Spatial Locations API

parent db6454cd
......@@ -44,7 +44,6 @@ extern "C" void scn_D_(setInputSpatialLocation)(void **m,
THFloatTensor_data(vec), sizeof(float) * nPlanes);
}
}
extern "C" void scn_D_(setInputSpatialLocations)(void **m,
THFloatTensor *features,
THLongTensor *locations,
......@@ -85,7 +84,28 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
THFloatTensor_free(vec);
}
}
extern "C" void scn_D_(getSpatialLocations)(void **m,
THLongTensor *spatialSize,
THLongTensor *locations) {
SCN_INITIALIZE_AND_REFERENCE(Metadata<Dimension>, m)
uInt nActive = _m.getNActive(spatialSize);
auto &SGs = _m.getSparseGrid(spatialSize);
uInt batchSize = SGs.size();
THLongTensor_resize2d(locations, nActive, Dimension);
THLongTensor_zero(locations);
auto lD = THLongTensor_data(locations);
for (uInt i = 0; i < batchSize; i++) {
auto mp = SGs[i].mp;
for (auto it = mp.begin(); it != mp.end(); ++it) {
for (uInt d = 0; d < Dimension; ++d) {
lD[it->second * Dimension + d] = it->first[d];
}
}
}
}
extern "C" void
scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_,
THLongTensor *pad_,
......
......@@ -28,6 +28,9 @@ void scn_1_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_1_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_1_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_2_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -47,6 +50,9 @@ void scn_2_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_2_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_2_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_3_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -66,6 +72,9 @@ void scn_3_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_3_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_3_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_4_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -85,6 +94,9 @@ void scn_4_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_4_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_4_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_5_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -104,6 +116,9 @@ void scn_5_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_5_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_5_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_6_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -123,6 +138,9 @@ void scn_6_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_6_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_6_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_7_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -142,6 +160,9 @@ void scn_7_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_7_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_7_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_8_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -161,6 +182,9 @@ void scn_8_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_8_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_8_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_9_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -180,6 +204,9 @@ void scn_9_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_9_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_9_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
double scn_10_addSampleFromThresholdedTensor(void **m, THFloatTensor *features_,
THFloatTensor *tensor_,
THLongTensor *offset_,
......@@ -199,6 +226,9 @@ void scn_10_setInputSpatialLocation(void **m, THFloatTensor *features,
void scn_10_setInputSpatialLocations(void **m, THFloatTensor *features,
THLongTensor *locations, THFloatTensor *vecs,
_Bool overwrite);
void scn_10_getSpatialLocations(void **m,
THLongTensor *spatialSize,
THLongTensor *locations);
void scn_cpu_float_AffineReluTrivialConvolution_updateOutput(
THFloatTensor *input_features, THFloatTensor *output_features,
THFloatTensor *affineWeight, THFloatTensor *affineBias,
......
......@@ -4,6 +4,9 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from ..utils import dim_fn
class SparseConvNetTensor(object):
def __init__(self, features=None, metadata=None, spatial_size=None):
......@@ -11,6 +14,14 @@ class SparseConvNetTensor(object):
self.metadata = metadata
self.spatial_size = spatial_size
def getSpatialLocations(self, spatial_size=None):
if spatial_size is None:
spatial_size = self.spatial_size
t = torch.LongTensor()
dim_fn(self.metadata.dimension, 'getSpatialLocations')(self.metadata.ffi, spatial_size, t)
return t
def type(self, t=None):
if t:
self.features = self.features.type(t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment