Unverified Commit c959dab8 authored by Krishna Murthy, committed by GitHub

Make compatible with pytorch 1.11 and newer; bugfix (#418)



* Bugfix: missing include common.h
Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Make compatible with pytorch 1.11 and newer -- THCudaCheck() deprecated
Signed-off-by: Krishna Murthy <krrish94@gmail.com>

---------
Signed-off-by: Krishna Murthy <krrish94@gmail.com>
parent 331ecdd5
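
The change itself is mechanical: PyTorch 1.11 removed the THC headers, so THCudaCheck() is no longer available and post-call error checks go through the c10 macro instead. Below is a minimal, self-contained sketch of that pattern; the kernel, function, and tensor names are illustrative placeholders rather than code from this repository. C10_CUDA_CHECK is provided by c10/cuda/CUDAException.h and reports failures as a c10::Error exception.

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAException.h>  // defines C10_CUDA_CHECK
    #include <cuda_runtime.h>

    // Illustrative kernel, not part of this repository.
    __global__ void fill_ones_kernel(float *out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = 1.0f;
    }

    // Sketch of the post-1.11 error-checking pattern used throughout this patch:
    // allocate or launch, then immediately surface any pending CUDA error.
    at::Tensor fill_ones(int64_t n) {
        at::Tensor out = at::empty(
            {n}, at::TensorOptions().device(at::kCUDA).dtype(at::kFloat));
        // Before: THCudaCheck(cudaGetLastError());
        C10_CUDA_CHECK(cudaGetLastError());

        const int threads = 256;
        const int blocks = static_cast<int>((n + threads - 1) / threads);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
        fill_ones_kernel<<<blocks, threads, 0, stream>>>(
            out.data_ptr<float>(), static_cast<int>(n));
        // After: check the launch the same way the patched files do.
        C10_CUDA_CHECK(cudaGetLastError());
        return out;
    }
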
@@ -19,8 +19,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <THC/THCNumerics.cuh>
-#include <THC/THC.h>
 #include <cuda.h>
@@ -378,14 +376,14 @@ std::vector<at::Tensor> box_encoder(const int N_img,
     printf("allocating %lu bytes for output labels\n", N_img*M*sizeof(long));
 #endif
     at::Tensor labels_out = at::empty({N_img * M}, labels_input.options());
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
 
     // copy default boxes to outputs
 #ifdef DEBUG
     printf("allocating %lu bytes for output bboxes\n", N_img*M*4*sizeof(float));
 #endif
     at::Tensor bbox_out = dbox.repeat({N_img, 1});
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
 
     // need to allocate some workspace
 #ifdef DEBUG
@@ -393,7 +391,7 @@ std::vector<at::Tensor> box_encoder(const int N_img,
 #endif
     // at::Tensor workspace = at::CUDA(at::kByte).zeros({8 * M * N_img});
     at::Tensor workspace = at::zeros({8 * M * N_img}, at::CUDA(at::kByte));
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
 
     // Encode the inputs
     const int THREADS_PER_BLOCK = 256;
@@ -408,7 +406,7 @@ std::vector<at::Tensor> box_encoder(const int N_img,
                                       (float4*)bbox_out.data_ptr<float>(),
                                       labels_out.data_ptr<long>());
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
 
     return {bbox_out, labels_out};
 }
@@ -435,7 +433,7 @@ at::Tensor calc_ious(const int N_img,
                                 (float4*)boxes2.data_ptr<float>(),
                                 ious.data_ptr<float>());
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
 
     return ious;
 }
@@ -569,7 +567,7 @@ std::vector<at::Tensor> random_horiz_flip(
             flip.data_ptr<float>(),
             tmp_img.data_ptr<scalar_t>(),
             nhwc);
-        THCudaCheck(cudaGetLastError());
+        C10_CUDA_CHECK(cudaGetLastError());
     });
     // copy tmp_img -> img
@@ -10,6 +10,8 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/KernelUtils.h>
+#include "common.h"
+
 template <typename dest_t, typename src_t>
 static inline dest_t safe_downcast(src_t v)
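
A side note on behavior, not part of this commit: when the wrapped expression returns a non-success cudaError_t, C10_CUDA_CHECK raises a c10::Error carrying the CUDA error string, so callers of these bindings can catch it like any other ATen exception. A tiny standalone sketch with an intentionally failing call (the out-of-range device index is a deliberate, hypothetical bad value):

    #include <c10/cuda/CUDAException.h>
    #include <c10/util/Exception.h>
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        try {
            // Deliberately invalid device index to force a cudaError_t failure.
            C10_CUDA_CHECK(cudaSetDevice(1 << 20));
        } catch (const c10::Error &e) {
            // The error message includes the CUDA error string.
            std::printf("caught c10::Error: %s\n", e.what());
        }
        return 0;
    }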