Commit 7f9d2eb5 authored by Kai Chen's avatar Kai Chen
Browse files

fix extension to fit pytorch 0.4.1 api

parent 0e0b9246
......@@ -17,9 +17,9 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int pooled_height, const int pooled_width,
at::Tensor bottom_grad);
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
......
......@@ -16,9 +16,9 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int num_rois, const int pooled_h,
const int pooled_w, at::Tensor bottom_grad);
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment