Commit 1df084f2 authored by traveller59's avatar traveller59
Browse files

fix problem with torch 1.4

parent 6e727bcd
...@@ -37,12 +37,12 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -37,12 +37,12 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter = std::max_element( auto indicePairMaxSizeIter = std::max_element(
indicePairNumCpu.data<int>(), indicePairNumCpu.data<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>(), indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumCpu.data<int>(); int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
/*if (_subM){ /*if (_subM){
std::vector<int> indicePairNumVec(indicePairNumCpu.data<int>(), indicePairNumCpu.data<int>() + kernelVolume); std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(), indicePairNumCpu.data_ptr<int>() + kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset); indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
auto indicePairVecMaxSizeIter = std::max_element( auto indicePairVecMaxSizeIter = std::max_element(
...@@ -68,15 +68,15 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -68,15 +68,15 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
double totalGEMMTime = 0; double totalGEMMTime = 0;
double totalSAddTime = 0; double totalSAddTime = 0;
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = auto outputBufferBlob =
torch::from_blob(outputBuffer.data<T>(), {nHot, numOutPlanes}, options); torch::from_blob(outputBuffer.data_ptr<T>(), {nHot, numOutPlanes}, options);
auto inputBufferBlob = auto inputBufferBlob =
torch::from_blob(inputBuffer.data<T>(), {nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr<T>(), {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor; functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor;
......
...@@ -37,7 +37,7 @@ torch::Tensor pointPillarScatter(torch::Tensor features, torch::Tensor coors, ...@@ -37,7 +37,7 @@ torch::Tensor pointPillarScatter(torch::Tensor features, torch::Tensor coors,
tv::check_torch_dtype<int>(shape); tv::check_torch_dtype<int>(shape);
tv::check_torch_dtype<T>(coors); tv::check_torch_dtype<T>(coors);
auto shapeData = shape.data<int>(); auto shapeData = shape.data_ptr<int>();
torch::Tensor canvas = torch::Tensor canvas =
torch::zeros({shapeData[0], shapeData[1], shapeData[2], shapeData[3]}, torch::zeros({shapeData[0], shapeData[1], shapeData[2], shapeData[3]},
features.options()); features.options());
......
...@@ -33,7 +33,7 @@ torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs, ...@@ -33,7 +33,7 @@ torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
torch::Tensor output = torch::zeros({numAct, numInPlanes}, options); torch::Tensor output = torch::zeros({numAct, numInPlanes}, options);
double totalTime = 0; double totalTime = 0;
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) { if (nHot <= 0) {
continue; continue;
} }
...@@ -75,7 +75,7 @@ torch::Tensor indiceMaxPoolBackward(torch::Tensor features, ...@@ -75,7 +75,7 @@ torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
torch::Tensor inputGrad = torch::zeros(features.sizes(), options); torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
auto kernelVolume = indicePairs.size(0); auto kernelVolume = indicePairs.size(0);
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) { if (nHot <= 0) {
continue; continue;
} }
......
...@@ -341,15 +341,15 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -341,15 +341,15 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter = auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data<int>(), std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset = int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data<int>(); indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
/*if (_subM){ /*if (_subM){
std::vector<int> indicePairNumVec(indicePairNumCpu.data<int>(), std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>() + kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset); indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
auto indicePairVecMaxSizeIter = std::max_element( auto indicePairVecMaxSizeIter = std::max_element(
...@@ -376,15 +376,15 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters, ...@@ -376,15 +376,15 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
double totalGEMMTime = 0; double totalGEMMTime = 0;
double totalSAddTime = 0; double totalSAddTime = 0;
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = auto outputBufferBlob =
torch::from_blob(outputBuffer.data<T>(), {nHot, numOutPlanes}, options); torch::from_blob(outputBuffer.data_ptr<T>(), {nHot, numOutPlanes}, options);
auto inputBufferBlob = auto inputBufferBlob =
torch::from_blob(inputBuffer.data<T>(), {nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr<T>(), {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor; functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor;
...@@ -460,10 +460,10 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -460,10 +460,10 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter = auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data<int>(), std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset = int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data<int>(); indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
auto options = auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device()); torch::TensorOptions().dtype(features.dtype()).device(features.device());
...@@ -483,7 +483,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -483,7 +483,7 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch::mm_out(inputGrad, outGrad, filters[indicePairMaxOffset].t()); torch::mm_out(inputGrad, outGrad, filters[indicePairMaxOffset].t());
} }
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
...@@ -521,9 +521,9 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters, ...@@ -521,9 +521,9 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
auto filterGradSub = filtersGrad[i]; auto filterGradSub = filtersGrad[i];
auto outputBufferBlob = auto outputBufferBlob =
torch::from_blob(outputBuffer.data<T>(), {nHot, numOutPlanes}, options); torch::from_blob(outputBuffer.data_ptr<T>(), {nHot, numOutPlanes}, options);
auto inputBufferBlob = auto inputBufferBlob =
torch::from_blob(inputBuffer.data<T>(), {nHot, numInPlanes}, options); torch::from_blob(inputBuffer.data_ptr<T>(), {nHot, numInPlanes}, options);
torch::mm_out(filterGradSub, inputBufferBlob.t(), outputBufferBlob); torch::mm_out(filterGradSub, inputBufferBlob.t(), outputBufferBlob);
torch::mm_out(inputBufferBlob, outputBufferBlob, filters[i].t()); torch::mm_out(inputBufferBlob, outputBufferBlob, filters[i].t());
...@@ -566,15 +566,15 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters, ...@@ -566,15 +566,15 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters,
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto totalActsTen = indicePairNumCpu.sum(); auto totalActsTen = indicePairNumCpu.sum();
auto totalActs = indicePairNumCpu.data<int>()[0]; auto totalActs = indicePairNumCpu.data_ptr<int>()[0];
auto indicePairMaxSizeIter = auto indicePairMaxSizeIter =
std::max_element(indicePairNumCpu.data<int>(), std::max_element(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data<int>() + kernelVolume); indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset = int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data<int>(); indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
std::vector<int> indicePairNumVec(indicePairNumCpu.data<int>(), std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data<int>() + indicePairNumCpu.data_ptr<int>() +
kernelVolume); kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset); indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
int subRuleMaxSize = int subRuleMaxSize =
...@@ -604,14 +604,14 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters, ...@@ -604,14 +604,14 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters,
double totalSAddTime = 0; double totalSAddTime = 0;
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
// //
auto outputBufferBlob = torch::from_blob(outputBuffer[i].data<T>(), auto outputBufferBlob = torch::from_blob(outputBuffer[i].data_ptr<T>(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer[i].data<T>(), auto inputBufferBlob = torch::from_blob(inputBuffer[i].data_ptr<T>(),
{nHot, numInPlanes}, options); {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor; functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor;
...@@ -642,13 +642,13 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters, ...@@ -642,13 +642,13 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters,
} }
// totalGatherTime += timer.report() / 1000.0; // totalGatherTime += timer.report() / 1000.0;
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
auto outputBufferBlob = torch::from_blob(outputBuffer[i].data<T>(), auto outputBufferBlob = torch::from_blob(outputBuffer[i].data_ptr<T>(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer[i].data<T>(), auto inputBufferBlob = torch::from_blob(inputBuffer[i].data_ptr<T>(),
{nHot, numInPlanes}, options); {nHot, numInPlanes}, options);
torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]); torch::mm_out(outputBufferBlob, inputBufferBlob, filters[i]);
...@@ -656,13 +656,13 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters, ...@@ -656,13 +656,13 @@ indiceConvDevelopDontUse(torch::Tensor features, torch::Tensor filters,
// totalGEMMTime += timer.report() / 1000.0; // totalGEMMTime += timer.report() / 1000.0;
// totalGEMMTime += timer.report() / 1000.0; // totalGEMMTime += timer.report() / 1000.0;
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data<int>()[i]; auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0 || (subM && i == indicePairMaxOffset)) { if (nHot <= 0 || (subM && i == indicePairMaxOffset)) {
continue; continue;
} }
auto outputBufferBlob = torch::from_blob(outputBuffer[i].data<T>(), auto outputBufferBlob = torch::from_blob(outputBuffer[i].data_ptr<T>(),
{nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = torch::from_blob(inputBuffer[i].data<T>(), auto inputBufferBlob = torch::from_blob(inputBuffer[i].data_ptr<T>(),
{nHot, numInPlanes}, options); {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
......
...@@ -93,6 +93,6 @@ tv::TensorView<T> torch2tv(const torch::Tensor &tensor) { ...@@ -93,6 +93,6 @@ tv::TensorView<T> torch2tv(const torch::Tensor &tensor) {
for (auto i : tensor.sizes()) { for (auto i : tensor.sizes()) {
shape.push_back(i); shape.push_back(i);
} }
return tv::TensorView<T>(tensor.data<std::remove_const_t<T>>(), shape); return tv::TensorView<T>(tensor.data_ptr<std::remove_const_t<T>>(), shape);
} }
} // namespace tv } // namespace tv
\ No newline at end of file
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <torch/script.h>
#include <spconv/pool_ops.h> #include <spconv/pool_ops.h>
#include <spconv/spconv_ops.h> #include <spconv/spconv_ops.h>
#include <spconv/pillar_scatter_ops.h> #include <spconv/pillar_scatter_ops.h>
...@@ -19,7 +20,8 @@ ...@@ -19,7 +20,8 @@
#include <spconv/nms_ops.h> #include <spconv/nms_ops.h>
static auto registry = static auto registry =
torch::jit::RegisterOperators("spconv::get_indice_pairs_2d", &spconv::getIndicePair<2>) torch::RegisterOperators()
.op("spconv::get_indice_pairs_2d", &spconv::getIndicePair<2>)
.op("spconv::get_indice_pairs_3d", &spconv::getIndicePair<3>) .op("spconv::get_indice_pairs_3d", &spconv::getIndicePair<3>)
.op("spconv::get_indice_pairs_4d", &spconv::getIndicePair<4>) .op("spconv::get_indice_pairs_4d", &spconv::getIndicePair<4>)
.op("spconv::get_indice_pairs_grid_2d", &spconv::getIndicePairPreGrid<2>) .op("spconv::get_indice_pairs_grid_2d", &spconv::getIndicePairPreGrid<2>)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment