Unverified Commit 9ff29e24 authored by Liu Yihua's avatar Liu Yihua Committed by GitHub
Browse files

feat: support torch>=1.11 (#1041)

* feat: support torch>=1.11

Fix #900.
Support PyTorch version >= 1.11. Referring to https://github.com/pytorch/pytorch/pull/66765 and https://github.com/pytorch/pytorch/wiki/TH-to-ATen-porting-guide.

* fix: Remove preproc torch version check macros
parent c233477a
...@@ -7,13 +7,10 @@ All Rights Reserved 2018. ...@@ -7,13 +7,10 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "ball_query_gpu.h" #include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
...@@ -9,11 +9,8 @@ All Rights Reserved 2018. ...@@ -9,11 +9,8 @@ All Rights Reserved 2018.
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h" #include "group_points_gpu.h"
extern THCState *state;
int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
......
...@@ -7,7 +7,6 @@ All Rights Reserved 2018. ...@@ -7,7 +7,6 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
...@@ -15,8 +14,6 @@ All Rights Reserved 2018. ...@@ -15,8 +14,6 @@ All Rights Reserved 2018.
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "interpolate_gpu.h" #include "interpolate_gpu.h"
extern THCState *state;
void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
...@@ -43,6 +40,7 @@ void three_interpolate_wrapper_fast(int b, int c, int m, int n, ...@@ -43,6 +40,7 @@ void three_interpolate_wrapper_fast(int b, int c, int m, int n,
three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out); three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out);
} }
void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
at::Tensor grad_out_tensor, at::Tensor grad_out_tensor,
at::Tensor idx_tensor, at::Tensor idx_tensor,
......
...@@ -8,12 +8,8 @@ All Rights Reserved 2018. ...@@ -8,12 +8,8 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h" #include "sampling_gpu.h"
extern THCState *state;
int gather_points_wrapper_fast(int b, int c, int n, int npoints, int gather_points_wrapper_fast(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
......
...@@ -7,13 +7,10 @@ All Rights Reserved 2019-2020. ...@@ -7,13 +7,10 @@ All Rights Reserved 2019-2020.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "ball_query_gpu.h" #include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
...@@ -28,6 +25,7 @@ extern THCState *state; ...@@ -28,6 +25,7 @@ extern THCState *state;
} while (0) } while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int ball_query_wrapper_stack(int B, int M, float radius, int nsample, int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) { at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) {
......
...@@ -9,10 +9,8 @@ All Rights Reserved 2019-2020. ...@@ -9,10 +9,8 @@ All Rights Reserved 2019-2020.
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h" #include "group_points_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
...@@ -7,7 +7,6 @@ All Rights Reserved 2019-2020. ...@@ -7,7 +7,6 @@ All Rights Reserved 2019-2020.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
...@@ -15,8 +14,6 @@ All Rights Reserved 2019-2020. ...@@ -15,8 +14,6 @@ All Rights Reserved 2019-2020.
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "interpolate_gpu.h" #include "interpolate_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h" #include "sampling_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
...@@ -10,13 +10,10 @@ All Rights Reserved 2020. ...@@ -10,13 +10,10 @@ All Rights Reserved 2020.
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "vector_pool_gpu.h" #include "vector_pool_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <vector> #include <vector>
#include <THC/THC.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
...@@ -8,8 +7,6 @@ ...@@ -8,8 +7,6 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include "voxel_query_gpu.h" #include "voxel_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \ #define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \ if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment