Commit 96addd52 authored by Benjamin Graham's avatar Benjamin Graham
Browse files

add batch information to getSpatialLocations

parent 86f46cb1
......@@ -100,6 +100,7 @@ extern "C" void scn_D_(setInputSpatialLocations)(void **m,
}
}
}
extern "C" void scn_D_(getSpatialLocations)(void **m,
THLongTensor *spatialSize,
THLongTensor *locations) {
......@@ -108,7 +109,7 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
auto &SGs = _m.getSparseGrid(spatialSize);
uInt batchSize = SGs.size();
THLongTensor_resize2d(locations, nActive, Dimension);
THLongTensor_resize2d(locations, nActive, Dimension + 1);
THLongTensor_zero(locations);
auto lD = THLongTensor_data(locations);
......@@ -117,8 +118,9 @@ extern "C" void scn_D_(getSpatialLocations)(void **m,
auto mp = SGs[i].mp;
for (auto it = mp.begin(); it != mp.end(); ++it) {
for (uInt d = 0; d < Dimension; ++d) {
lD[it->second * Dimension + d] = it->first[d];
lD[it->second * (Dimension + 1) + d] = it->first[d];
}
lD[it->second * (Dimension + 1) + Dimension] = i;
}
}
}
......
......@@ -49,6 +49,8 @@ return function (sparseconvnet)
void scn_DIMENSION_setInputSpatialLocation(void **m,
THFloatTensor *features, THLongTensor *location, THFloatTensor *vec,
bool overwrite);
void scn_5_getSpatialLocations(void **m,
THLongTensor *spatialSize, THLongTensor *locations);
void scn_DIMENSION_setInputSpatialLocations(void **m,
THFloatTensor *features, THLongTensor *locations, THFloatTensor *vecs,
bool overwrite);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment