Unverified commit 9ff29e24 authored by Liu Yihua, committed by GitHub
Browse files

feat: support torch>=1.11 (#1041)

* feat: support torch>=1.11

Fix #900.
Support PyTorch versions >= 1.11. See https://github.com/pytorch/pytorch/pull/66765 and https://github.com/pytorch/pytorch/wiki/TH-to-ATen-porting-guide for reference.

* fix: Remove preproc torch version check macros
parent c233477a
......@@ -7,13 +7,10 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
......@@ -9,11 +9,8 @@ All Rights Reserved 2018.
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"
extern THCState *state;
int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
......
......@@ -7,7 +7,6 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
......@@ -15,8 +14,6 @@ All Rights Reserved 2018.
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
extern THCState *state;
void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
......@@ -43,6 +40,7 @@ void three_interpolate_wrapper_fast(int b, int c, int m, int n,
three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out);
}
void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
......
......@@ -8,12 +8,8 @@ All Rights Reserved 2018.
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h"
extern THCState *state;
int gather_points_wrapper_fast(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
......
......@@ -7,13 +7,10 @@ All Rights Reserved 2019-2020.
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......@@ -28,6 +25,7 @@ extern THCState *state;
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) {
......
......@@ -9,10 +9,8 @@ All Rights Reserved 2019-2020.
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
......@@ -7,7 +7,6 @@ All Rights Reserved 2019-2020.
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
......@@ -15,8 +14,6 @@ All Rights Reserved 2019-2020.
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
......@@ -10,13 +10,10 @@ All Rights Reserved 2020.
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "vector_pool_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
......@@ -8,8 +7,6 @@
#include <cuda_runtime_api.h>
#include "voxel_query_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment