Unverified Commit 903d3b55 authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

support pytorch<=1.10

fix for #112
parent 0a091831
...@@ -7,7 +7,15 @@ ...@@ -7,7 +7,15 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/cub_definitions.cuh> // #include <ATen/cuda/cub_definitions.cuh>
// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \ #define CHECK_CONTIGUOUS(x) \
...@@ -31,4 +39,4 @@ ...@@ -31,4 +39,4 @@
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \ func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \ AT_CUDA_CHECK(cudaGetLastError()); \
} while (false) } while (false)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment