Commit 96addd52 authored by Benjamin Graham's avatar Benjamin Graham
Browse files

add batch information to getSpatialLocations

parent 86f46cb1
...@@ -100,6 +100,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m, ...@@ -100,6 +100,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
} }
} }
} }
extern "C" void scn_D_(getSpatialLocations)(void **m, extern "C" void scn_D_(getSpatialLocations)(void **m,
THLongTensor *spatialSize, THLongTensor *spatialSize,
THLongTensor *locations) { THLongTensor *locations) {
...@@ -108,7 +109,7 @@ extern "C" void scn_D_(getSpatialLocations)(void **m, ...@@ -108,7 +109,7 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
auto &SGs = _m.getSparseGrid(spatialSize); auto &SGs = _m.getSparseGrid(spatialSize);
uInt batchSize = SGs.size(); uInt batchSize = SGs.size();
THLongTensor_resize2d(locations, nActive, Dimension); THLongTensor_resize2d(locations, nActive, Dimension + 1);
THLongTensor_zero(locations); THLongTensor_zero(locations);
auto lD = THLongTensor_data(locations); auto lD = THLongTensor_data(locations);
...@@ -117,8 +118,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m, ...@@ -117,8 +118,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
auto mp = SGs[i].mp; auto mp = SGs[i].mp;
for (auto it = mp.begin(); it != mp.end(); ++it) { for (auto it = mp.begin(); it != mp.end(); ++it) {
for (uInt d = 0; d < Dimension; ++d) { for (uInt d = 0; d < Dimension; ++d) {
lD[it->second * Dimension + d] = it->first[d]; lD[it->second * (Dimension + 1) + d] = it->first[d];
} }
lD[it->second * (Dimension + 1) + Dimension] = i;
} }
} }
} }
......
...@@ -49,6 +49,8 @@ return function (sparseconvnet) ...@@ -49,6 +49,8 @@ return function (sparseconvnet)
void scn_DIMENSION_setInputSpatialLocation(void **m, void scn_DIMENSION_setInputSpatialLocation(void **m,
THFloatTensor *features, THLongTensor *location, THFloatTensor *vec, THFloatTensor *features, THLongTensor *location, THFloatTensor *vec,
bool overwrite); bool overwrite);
void scn_5_getSpatialLocations(void **m,
THLongTensor *spatialSize, THLongTensor *locations);
void scn_DIMENSION_setInputSpatialLocations(void **m, void scn_DIMENSION_setInputSpatialLocations(void **m,
THFloatTensor *features, THLongTensor *locations, THFloatTensor *vecs, THFloatTensor *features, THLongTensor *locations, THFloatTensor *vecs,
bool overwrite); bool overwrite);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment