// Copyright 2019 Yan Yan // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include namespace spconv { namespace functor { template struct CreateConvIndicePairFunctor { Index operator()(const tv::CPU& d, tv::TensorView indicesIn, tv::TensorView indicesOut, tv::TensorView gridsOut, tv::TensorView indicePairs, tv::TensorView indiceNum, const tv::SimpleVector kernelSize, const tv::SimpleVector stride, const tv::SimpleVector padding, const tv::SimpleVector dilation, const tv::SimpleVector outSpatialShape, bool transpose, bool resetGrid) { if (transpose) return getIndicePairsDeConv( indicesIn, indicesOut, gridsOut, indicePairs, indiceNum, kernelSize.data(), stride.data(), padding.data(), dilation.data(), outSpatialShape.data()); else return getIndicePairsConv( indicesIn, indicesOut, gridsOut, indicePairs, indiceNum, kernelSize.data(), stride.data(), padding.data(), dilation.data(), outSpatialShape.data()); } }; template struct CreateSubMIndicePairFunctor { Index operator()(const tv::CPU& d, tv::TensorView indicesIn, tv::TensorView gridsOut, tv::TensorView indicePairs, tv::TensorView indiceNum, const tv::SimpleVector kernelSize, const tv::SimpleVector stride, const tv::SimpleVector padding, const tv::SimpleVector dilation, const tv::SimpleVector outSpatialShape, bool transpose, bool resetGrid) { return getIndicePairsSubM( indicesIn, gridsOut, indicePairs, indiceNum, kernelSize.data(), stride.data(), padding.data(), dilation.data(), outSpatialShape.data()); } }; } // namespace functor #define DECLARE_CPU_SPECS_INDEX_NDIM(Index, NDIM) \ template struct functor::CreateConvIndicePairFunctor; \ template struct functor::CreateSubMIndicePairFunctor; #define DECLARE_CPU_INDEX(Index) \ DECLARE_CPU_SPECS_INDEX_NDIM(Index, 1); \ DECLARE_CPU_SPECS_INDEX_NDIM(Index, 2); \ DECLARE_CPU_SPECS_INDEX_NDIM(Index, 3); \ DECLARE_CPU_SPECS_INDEX_NDIM(Index, 4); DECLARE_CPU_INDEX(int); DECLARE_CPU_INDEX(long); #undef DECLARE_CPU_INDEX #undef DECLARE_CPU_SPECS_INDEX_NDIM } // namespace spconv