// Copyright 2016-present, Facebook, Inc. // All rights reserved. // // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. #include "Metadata.h" #include "ActivePoolingRules.h" #include "ConvolutionRules.h" #include "FullConvolutionRules.h" #include "IOLayersRules.h" #include "PermutohedralSubmanifoldConvolutionRules.h" #include "RandomizedStrideRules.h" #include "SubmanifoldConvolutionRules.h" template SparseGrid::SparseGrid() : ctr(0) { // Sparsehash needs a key to be set aside and never used Point empty_key; for (Int i = 0; i < dimension; ++i) empty_key[i] = std::numeric_limits::min(); mp.set_empty_key(empty_key); } template T *OptionalTensorData(at::Tensor tensor) { return tensor.numel() ? tensor.data() : nullptr; } template void addPointToSparseGridMapAndFeatures(SparseGridMap &mp, Point p, Int &nActive, long nPlanes, /*float*/ at::Tensor features, float *vec, bool overwrite) { auto iter = mp.find(p); if (iter == mp.end()) { iter = mp.insert(std::make_pair(p, nActive++)).first; features.resize_({(int)nActive, nPlanes}); std::memcpy(features.data() + (nActive - 1) * nPlanes, vec, sizeof(float) * nPlanes); } else if (overwrite) { std::memcpy(features.data() + iter->second * nPlanes, vec, sizeof(float) * nPlanes); } } template Metadata::Metadata() : re(std::chrono::system_clock::now().time_since_epoch().count()) {} template void Metadata::clear() { nActive.clear(); grids.clear(); activePoolingRuleBooks.clear(); inputLayerRuleBook.clear(); submanifoldRuleBooks.clear(); ruleBooks.clear(); fullConvolutionRuleBook.clear(); sparseToDenseRuleBooks.clear(); inputSGs = nullptr; inputSG = nullptr; inputNActive = nullptr; inputLayerRuleBook.clear(); blLayerRuleBook.clear(); } template Int Metadata::getNActive(/*long*/ at::Tensor spatialSize) { return nActive[LongTensorToPoint(spatialSize)]; }; template SparseGrids & Metadata::getSparseGrid(/*long*/ at::Tensor spatialSize) { return grids[LongTensorToPoint(spatialSize)]; }; template void Metadata::setInputSpatialSize(/*long*/ at::Tensor spatialSize) { inputSpatialSize = LongTensorToPoint(spatialSize); inputSGs = &grids[inputSpatialSize]; inputNActive = &nActive[inputSpatialSize]; } template void Metadata::batchAddSample() { assert(inputSGs && "Call setInputSpatialSize first, please!"); inputSGs->resize(inputSGs->size() + 1); inputSG = &inputSGs->back(); } template void Metadata::setInputSpatialLocation(/*float*/ at::Tensor features, /*long*/ at::Tensor location, /*float*/ at::Tensor vec, bool overwrite) { auto p = LongTensorToPoint(location); SparseGridMap &mp = inputSG->mp; Int &nActive = *inputNActive; auto nPlanes = vec.size(0); addPointToSparseGridMapAndFeatures( mp, p, nActive, nPlanes, features, vec.data(), overwrite); } template void Metadata::setInputSpatialLocations( /*float*/ at::Tensor features, /*long*/ at::Tensor locations, /*float*/ at::Tensor vecs, bool overwrite) { /* assert(locations.ndimension() == 2 and "locations must be 2 * dimensional!"); */ /* assert(vecs.ndimension() == 2 and "vecs must be 2 dimensional!"); */ /* assert(locations.size(0) == vecs.size(0) and */ /* "Location.size(0) and vecs.size(0) must be equal!"); */ /* assert((locations.size(1) == dimension or */ /* locations.size(1) == 1 + dimension) and */ /* "locations.size(0) must be either dimension or dimension+1"); */ Point p; Int &nActive = *inputNActive; auto nPlanes = vecs.size(1); long *l = locations.data(); float *v = vecs.data(); if (locations.size(1) == dimension) { // add points to current sample assert(inputSG); SparseGridMap &mp = inputSG->mp; for (Int idx = 0; idx < locations.size(0); ++idx) { for (Int d = 0; d < dimension; ++d) p[d] = *l++; addPointToSparseGridMapAndFeatures(mp, p, nActive, nPlanes, features, v, overwrite); v += nPlanes; } } if (locations.size(1) == dimension + 1) { // add new samples to batch as necessary auto &SGs = *inputSGs; for (Int idx = 0; idx < locations.size(0); ++idx) { for (Int d = 0; d < dimension; ++d) p[d] = *l++; Int batch = *l++; if (batch >= (Int)SGs.size()) { SGs.resize(batch + 1); } SparseGridMap &mp = SGs[batch].mp; addPointToSparseGridMapAndFeatures(mp, p, nActive, nPlanes, features, v, overwrite); v += nPlanes; } } } template at::Tensor Metadata::getSpatialLocations(/*long*/ at::Tensor spatialSize) { Int nActive = getNActive(spatialSize); auto &SGs = getSparseGrid(spatialSize); Int batchSize = SGs.size(); auto locations = torch::zeros({(int)nActive, dimension + 1}, at::kLong); auto lD = locations.data(); for (Int i = 0; i < batchSize; i++) { auto mp = SGs[i].mp; auto offset = SGs[i].ctr; for (auto it = mp.begin(); it != mp.end(); ++it) { for (Int d = 0; d < dimension; ++d) { lD[(it->second + offset) * (dimension + 1) + d] = it->first[d]; } lD[(it->second + offset) * (dimension + 1) + dimension] = i; } } return locations; } template void Metadata::createMetadataForDenseToSparse( /*long*/ at::Tensor spatialSize, /*long*/ at::Tensor nz_, long batchSize) { clear(); setInputSpatialSize(spatialSize); inputSGs->resize(batchSize); auto &nActive = *inputNActive; nActive = nz_.size(0); long *nz = nz_.data(); std::vector br(batchSize + 1); if (batchSize == 1) { br[1] = nActive; } else { long b = 0; for (Int i = 0; i < nActive; i++) { long B = nz[i * (dimension + 1)]; for (; b < B;) br[++b] = i; } for (; b < batchSize;) br[++b] = nActive; } Int b; #pragma omp parallel for private(b) for (b = 0; b < batchSize; b++) { auto &sg = inputSGs->at(b); for (Int i = br[b]; i < br[b + 1]; i++) { Point x; for (Int j = 0; j < dimension; j++) { x[j] = nz[i * (dimension + 1) + j + 1]; // 0-indexed } sg.mp[x] = i; } } } template void Metadata::sparsifyMetadata(Metadata &mOut, /*long*/ at::Tensor spatialSize, /*byte*/ at::Tensor filter, /*long*/ at::Tensor cuSum) { // Create a new SparseGrids with fewer entries. mOut.clear(); auto p = LongTensorToPoint(spatialSize); auto &sgsIn = grids[p]; auto &sgsOut = mOut.grids[p]; sgsOut.resize(sgsIn.size()); if (filter.ndimension() == 1) { auto f = filter.data(); auto cs = cuSum.data(); auto nActive = cs[cuSum.numel() - 1]; mOut.nActive[p] = nActive; Int sample; #pragma omp parallel for private(sample) for (sample = 0; sample < (Int)sgsIn.size(); ++sample) { auto &sgIn = sgsIn[sample]; auto &sgOut = sgsOut[sample]; for (auto const &iter : sgIn.mp) { auto n = iter.second + sgIn.ctr; if (f[n]) sgOut.mp[iter.first] = cs[n] - 1; } } } else { mOut.nActive[p] = 0; } } template void Metadata::appendMetadata(Metadata &mAdd, /*long*/ at::Tensor spatialSize) { auto p = LongTensorToPoint(spatialSize); auto &sgs1 = grids[p]; auto &sgs2 = mAdd.grids[p]; auto &nActive1 = nActive[p]; auto &nActive2 = mAdd.nActive[p]; Int bs1 = sgs1.size(); Int bs2 = sgs2.size(); sgs1.insert(sgs1.end(), sgs2.begin(), sgs2.end()); for (Int i = bs1; i < bs1 + bs2; ++i) sgs1[i].ctr += nActive1; nActive1 += nActive2; } template at::Tensor Metadata::sparsifyCompare(Metadata &mReference, Metadata &mSparsified, /*long*/ at::Tensor spatialSize) { auto p = LongTensorToPoint(spatialSize); at::Tensor delta = torch::zeros({nActive[p]}, at::kFloat); float *deltaPtr = delta.data(); auto &sgsReference = mReference.grids[p]; auto &sgsFull = grids[p]; auto &sgsSparsified = mSparsified.grids[p]; Int batchSize = sgsFull.size(); Int sample; #pragma omp parallel for private(sample) for (sample = 0; sample < (Int)batchSize; ++sample) { auto &sgReference = sgsReference[sample]; auto &sgFull = sgsFull[sample]; auto &sgSparsified = sgsSparsified[sample]; for (auto const &iter : sgFull.mp) { bool gt = sgReference.mp.find(iter.first) != sgReference.mp.end(); bool hot = sgSparsified.mp.find(iter.first) != sgSparsified.mp.end(); if (gt and not hot) deltaPtr[iter.second + sgFull.ctr] = -1; if (hot and not gt) deltaPtr[iter.second + sgFull.ctr] = +1; } } return delta; } // tensor is size[0] x .. x size[dimension-1] x size[dimension] // size[0] x .. x size[dimension-1] == spatial volume // size[dimension] == #feature planes template void Metadata::addSampleFromThresholdedTensor( /*float*/ at::Tensor features_, /*float*/ at::Tensor tensor_, /*long*/ at::Tensor offset_, /*long*/ at::Tensor spatialSize_, float threshold) { auto &nActive = *inputNActive; auto &SGs = *inputSGs; SGs.resize(SGs.size() + 1); auto &sg = SGs.back(); auto tensor = tensor_.data(); auto offset = offset_.data(); auto spatialSize = spatialSize_.data(); long size[dimension + 1]; // IntList? for (Int i = 0; i <= dimension; ++i) size[i] = tensor_.size(i); // std::vector size = tensor_.size(); auto nPlanes = size[dimension]; long volume = 1; for (Int i = 0; i < dimension; ++i) volume *= size[i]; features_.resize_({(int)(nActive + volume), nPlanes}); // Increment pointers as we work through the data auto features = features_.data() + nActive * nPlanes; // Active locations Point point; for (Int i = 0; i < dimension; i++) point[i] = offset[i]; for (Int ctr = 0; ctr < volume; ctr++) { bool active = false; for (Int i = 0; i < nPlanes; i++) { if (fabs(tensor[i]) > threshold) { active = true; break; } } for (Int i = 0; i < dimension; i++) { if (point[i] < 0 or point[i] >= spatialSize[i]) { active = false; break; } } if (active) { sg.mp[point] = nActive++; std::memcpy(features, tensor, sizeof(float) * nPlanes); features += nPlanes; } tensor += nPlanes; incrementPointInCube(point, size, offset); } features_.resize_({(int)nActive, nPlanes}); } // 3x3 submanifold convolutions, 3x3/2x2 pooling or strided convolutions template void Metadata::generateRuleBooks3s2() { long sz[dimension], str[dimension], inS[dimension], outS[dimension]; Point p1; Point<2 * dimension> p2; Point<3 * dimension> p3; for (Int i = 0; i < dimension; ++i) { p1[i] = p2[i] = p3[i] = inS[i] = inputSpatialSize[i]; p2[i + dimension] = p3[i + dimension] = sz[i] = 3; p3[i + 2 * dimension] = str[i] = 2; } while (true) { auto &SGs = grids[p1]; auto &rb = submanifoldRuleBooks[p2]; if (rb.empty()) SubmanifoldConvolution_SgsToRules(SGs, rb, sz); for (Int i = 0; i < dimension; ++i) if (p1[i] < 3 or p1[i] % 2 != 1) return; else p1[i] = outS[i] = (inS[i] - 1) / 2; auto &SGs2 = grids[p1]; auto &rb2 = ruleBooks[p3]; if (rb2.empty()) nActive[p1] = Convolution_InputSgsToRulesAndOutputSgs(SGs, SGs2, rb2, sz, str, inS, outS); for (Int i = 0; i < dimension; ++i) p2[i] = p3[i] = inS[i] = outS[i]; } } // 3x3 submanifold convolutions, 2x2 pooling or strided convolutions template void Metadata::generateRuleBooks2s2() { long s2[dimension], s3[dimension], inS[dimension], outS[dimension]; Point p1; Point<2 * dimension> p2; Point<3 * dimension> p3; for (Int i = 0; i < dimension; ++i) { p1[i] = p2[i] = p3[i] = inS[i] = inputSpatialSize[i]; p2[i + dimension] = s3[i] = 3; p3[i + dimension] = p3[i + 2 * dimension] = s2[i] = 2; } while (true) { auto &SGs = grids[p1]; auto &rb = submanifoldRuleBooks[p2]; if (rb.empty()) SubmanifoldConvolution_SgsToRules(SGs, rb, s3); for (Int i = 0; i < dimension; ++i) if (p1[i] < 2 or p1[i] % 2 != 0) return; else p1[i] = outS[i] = inS[i] / 2; auto &SGs2 = grids[p1]; auto &rb2 = ruleBooks[p3]; if (rb2.empty()) nActive[p1] = Convolution_InputSgsToRulesAndOutputSgs(SGs, SGs2, rb2, s2, s2, inS, outS); for (Int i = 0; i < dimension; ++i) p2[i] = p3[i] = inS[i] = outS[i]; } } template void Metadata::inputLayer(/*long*/ at::Tensor spatialSize, /*long*/ at::Tensor coords, Int batchSize, Int mode) { assert(spatialSize.ndimension() == 1); assert(spatialSize.size(0) == dimension); assert(coords.ndimension() == 2); assert(coords.size(1) >= dimension and coords.size(1) <= dimension + 1); setInputSpatialSize(spatialSize); inputLayerRules(*inputSGs, inputLayerRuleBook, coords.data(), coords.size(0), coords.size(1), batchSize, mode, *inputNActive); } template void Metadata::blLayer(/*long*/ at::Tensor spatialSize, /*long*/ at::Tensor coords, Int mode) { assert(spatialSize.ndimension() == 1); assert(spatialSize.size(0) == dimension); assert(coords.ndimension() == 3); assert(coords.size(2) == dimension); setInputSpatialSize(spatialSize); blRules(*inputSGs, blLayerRuleBook, coords.data(), coords.size(0), coords.size(1), mode, *inputNActive); } template RuleBook &Metadata::getSubmanifoldRuleBook( /*long*/ at::Tensor spatialSize, /*long*/ at::Tensor size, bool openMP) { auto p = TwoLongTensorsToPoint(spatialSize, size); auto &rb = submanifoldRuleBooks[p]; if (rb.empty()) { auto &SGs = grids[LongTensorToPoint(spatialSize)]; #if defined(ENABLE_OPENMP) openMP ? SubmanifoldConvolution_SgsToRules_OMP(SGs, rb, size.data()) : #endif SubmanifoldConvolution_SgsToRules(SGs, rb, size.data()); } return rb; } template RuleBook &Metadata::getPermutohedralSubmanifoldRuleBook( /*long*/ at::Tensor spatialSize, bool openMP) { auto p = LongTensorToPoint(spatialSize); auto &rb = permutohedralRuleBooks[p]; if (rb.empty()) { auto &SGs = grids[LongTensorToPoint(spatialSize)]; #if defined(ENABLE_OPENMP) openMP ? PermutohedralSubmanifoldConvolution_SgsToRules_OMP(SGs, rb) : #endif PermutohedralSubmanifoldConvolution_SgsToRules(SGs, rb); } return rb; } template RuleBook &Metadata::getActivePoolingRuleBook( /*long*/ at::Tensor spatialSize) { auto spatialSz = LongTensorToPoint(spatialSize); auto &SGs = grids[spatialSz]; auto &rb = activePoolingRuleBooks[spatialSz]; if (rb.empty()) activePoolingRules(SGs, rb); return rb; } template RuleBook &Metadata::getSparseToDenseRuleBook( /*long*/ at::Tensor spatialSize, bool openMP) { auto ss = LongTensorToPoint(spatialSize); auto &SGs = grids[ss]; auto &rb = sparseToDenseRuleBooks[ss]; if (rb.empty()) #if defined(ENABLE_OPENMP) openMP ? SparseToDense_InputSgsToRulesAndOutputSgs_OMP( SGs, rb, spatialSize.data()) : #endif SparseToDense_InputSgsToRulesAndOutputSgs(SGs, rb, spatialSize.data()); return rb; } template RuleBook &Metadata::getRuleBook( /*long*/ at::Tensor inputSpatialSize, /*long*/ at::Tensor outputSpatialSize, /*long*/ at::Tensor size, /*long*/ at::Tensor stride, bool openMP) { auto p = ThreeLongTensorsToPoint(inputSpatialSize, size, stride); auto &rb = ruleBooks[p]; if (rb.empty()) { auto iS = LongTensorToPoint(inputSpatialSize); auto oS = LongTensorToPoint(outputSpatialSize); auto &iSGs = grids[iS]; auto &oSGs = grids[oS]; nActive[oS] = #if defined(ENABLE_OPENMP) openMP ? Convolution_InputSgsToRulesAndOutputSgs_OMP( iSGs, oSGs, rb, size.data(), stride.data(), inputSpatialSize.data(), outputSpatialSize.data()) : #endif Convolution_InputSgsToRulesAndOutputSgs( iSGs, oSGs, rb, size.data(), stride.data(), inputSpatialSize.data(), outputSpatialSize.data()); } return rb; } template RuleBook &Metadata::getFullConvolutionRuleBook( /*long*/ at::Tensor inputSpatialSize, /*long*/ at::Tensor outputSpatialSize, /*long*/ at::Tensor size, /*long*/ at::Tensor stride, Metadata &newM) { auto &rb = newM.fullConvolutionRuleBook; if (rb.empty()) { newM.clear(); auto iS = LongTensorToPoint(inputSpatialSize); auto oS = LongTensorToPoint(outputSpatialSize); newM.grids[iS] = grids[iS]; // copy newM.nActive[iS] = nActive[iS]; auto &iSGs = newM.grids[iS]; auto &oSGs = newM.grids[oS]; newM.nActive[oS] = FullConvolution_InputSgsToRulesAndOutputSgs_OMP( iSGs, oSGs, rb, size.data(), stride.data(), inputSpatialSize.data(), outputSpatialSize.data()); } return rb; } template RuleBook &Metadata::getRandomizedStrideRuleBook( /*long*/ at::Tensor inputSpatialSize, /*long*/ at::Tensor outputSpatialSize, /*long*/ at::Tensor size, /*long*/ at::Tensor stride, bool openMP) { auto p = ThreeLongTensorsToPoint(inputSpatialSize, size, stride); auto &rb = ruleBooks[p]; if (rb.empty()) { auto iS = LongTensorToPoint(inputSpatialSize); auto oS = LongTensorToPoint(outputSpatialSize); auto &iSGs = grids[iS]; auto &oSGs = grids[oS]; nActive[oS] = #if defined(ENABLE_OPENMP) openMP ? RSR_InputSgsToRulesAndOutputSgs_OMP( iSGs, oSGs, rb, size.data(), stride.data(), inputSpatialSize.data(), outputSpatialSize.data(), re) : #endif RSR_InputSgsToRulesAndOutputSgs(iSGs, oSGs, rb, size.data(), stride.data(), inputSpatialSize.data(), outputSpatialSize.data(), re); } return rb; } template std::vector Metadata::compareSparseHelper(Metadata &mR, /* long */ at::Tensor spatialSize) { auto p = LongTensorToPoint(spatialSize); auto &sgsL = grids[p]; auto &sgsR = mR.grids[p]; std::vector cL, cR, L, R; for (Int sample = 0; sample < (Int)sgsL.size(); ++sample) { auto &sgL = sgsL[sample]; auto &sgR = sgsR[sample]; for (auto const &iter : sgL.mp) { if (sgR.mp.find(iter.first) == sgR.mp.end()) { L.push_back(sgL.mp[iter.first] + sgL.ctr); } else { cL.push_back(sgL.mp[iter.first] + sgL.ctr); cR.push_back(sgR.mp[iter.first] + sgR.ctr); } } for (auto const &iter : sgR.mp) { if (sgL.mp.find(iter.first) == sgL.mp.end()) { R.push_back(sgR.mp[iter.first] + sgR.ctr); } } } at::Tensor cL_ = torch::empty({(long)cL.size()}, at::CPU(at::kLong)); std::memcpy(cL_.data(), &cL[0], cL.size() * sizeof(long)); at::Tensor cR_ = torch::empty({(long)cR.size()}, at::CPU(at::kLong)); std::memcpy(cR_.data(), &cR[0], cR.size() * sizeof(long)); at::Tensor L_ = torch::empty({(long)L.size()}, at::CPU(at::kLong)); std::memcpy(L_.data(), &L[0], L.size() * sizeof(long)); at::Tensor R_ = torch::empty({(long)R.size()}, at::CPU(at::kLong)); std::memcpy(R_.data(), &R[0], R.size() * sizeof(long)); return {cL_, cR_, L_, R_}; } template Int volume(long *point) { Int v = 1; for (Int i = 0; i < dimension; i++) v *= point[i]; return v; }