Unverified Commit 3c33f367 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Per file C++ Operator registration (#3135)

* Moving deform_conv2d op registration.

* Moving nms op registration.

* Moving new_empty_tensor op registration.

* Moving ps_roi_align op registration.

* Moving ps_roi_pool op registration.

* Moving roi_align op registration.

* Moving roi_pool op registration.

* Restoring headers for forward/backward and fixing styles.

* Restoring the test hack on windows.

* Stricter header inclusion.
parent 6cb4fc21
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
// CUDA forward entry point for Position-Sensitive RoI Pooling.
// Returns a pair of tensors; the second is passed back into the backward
// declaration below as `channel_mapping`.
VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
// CUDA backward entry point: gradient of the forward pass.
// NOTE(review): batch_size/channels/height/width presumably describe the
// original input's shape so the gradient can be allocated — confirm in the
// kernel implementation.
VISION_API at::Tensor ps_roi_pool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
} // namespace ops
} // namespace vision
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <THC/THCAtomics.cuh>
#include "cuda_helpers.h"
#include "roi_align_kernel.h"
namespace vision {
namespace ops {
......@@ -314,9 +315,7 @@ __global__ void roi_align_backward_kernel_impl(
} // CUDA_1D_KERNEL_LOOP
}
} // namespace
at::Tensor roi_align_forward_cuda(
at::Tensor roi_align_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
......@@ -330,7 +329,7 @@ at::Tensor roi_align_forward_cuda(
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_forward_cuda";
at::CheckedFrom c = "roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
......@@ -359,7 +358,7 @@ at::Tensor roi_align_forward_cuda(
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_align_forward_cuda", [&] {
input.scalar_type(), "roi_align_forward_kernel", [&] {
roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input_.data_ptr<scalar_t>(),
......@@ -378,7 +377,7 @@ at::Tensor roi_align_forward_cuda(
return output;
}
at::Tensor roi_align_backward_cuda(
at::Tensor roi_align_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
......@@ -395,7 +394,7 @@ at::Tensor roi_align_backward_cuda(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_backward_cuda";
at::CheckedFrom c = "roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});
......@@ -424,7 +423,7 @@ at::Tensor roi_align_backward_cuda(
auto rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "roi_align_backward_cuda", [&] {
grad.scalar_type(), "roi_align_backward_kernel", [&] {
roi_align_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.data_ptr<scalar_t>(),
......@@ -447,5 +446,12 @@ at::Tensor roi_align_backward_cuda(
return grad_input;
}
} // namespace
// Register the CUDA kernels for roi_align and its backward op directly in
// this translation unit (per-file registration, replacing the central list
// that used to live in vision.cpp).
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
}
} // namespace ops
} // namespace vision
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
// CUDA forward entry point for RoI Align.
// NOTE(review): the .cu file in this change renames these functions to
// roi_align_forward_kernel/roi_align_backward_kernel — this header appears
// to be the stale pre-rename version; confirm it is deleted or updated.
VISION_API at::Tensor roi_align_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned);
// CUDA backward entry point: gradient of the forward pass.
// NOTE(review): batch_size/channels/height/width presumably give the
// original input's shape for allocating the gradient — confirm in the kernel.
VISION_API at::Tensor roi_align_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned);
} // namespace ops
} // namespace vision
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <torch/library.h>
#include <THC/THCAtomics.cuh>
#include "cuda_helpers.h"
#include "roi_pool_kernel.h"
namespace vision {
namespace ops {
......@@ -120,9 +121,7 @@ __global__ void roi_pool_backward_kernel_impl(
}
}
} // namespace
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
......@@ -135,7 +134,7 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_pool_forward_cuda";
at::CheckedFrom c = "roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
......@@ -167,7 +166,7 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_pool_forward_cuda", [&] {
input.scalar_type(), "roi_pool_forward_kernel", [&] {
roi_pool_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input_.data_ptr<scalar_t>(),
......@@ -185,7 +184,7 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
return std::make_tuple(output, argmax);
}
at::Tensor roi_pool_backward_cuda(
at::Tensor roi_pool_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
......@@ -204,7 +203,7 @@ at::Tensor roi_pool_backward_cuda(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
argmax_t{argmax, "argmax", 3};
at::CheckedFrom c = "roi_pool_backward_cuda";
at::CheckedFrom c = "roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
at::checkAllSameType(c, {grad_t, rois_t});
......@@ -235,7 +234,7 @@ at::Tensor roi_pool_backward_cuda(
auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "roi_pool_backward_cuda", [&] {
grad.scalar_type(), "roi_pool_backward_kernel", [&] {
roi_pool_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.data_ptr<scalar_t>(),
......@@ -258,5 +257,12 @@ at::Tensor roi_pool_backward_cuda(
return grad_input;
}
} // namespace
// Register the CUDA kernels for roi_pool and its backward op directly in
// this translation unit (per-file registration).
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
}
} // namespace ops
} // namespace vision
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
// CUDA forward entry point for RoI Pooling.
// Returns a pair of tensors; the second is passed back into the backward
// declaration below as `argmax`.
// NOTE(review): the .cu file in this change renames these functions to
// roi_pool_forward_kernel/roi_pool_backward_kernel — this header appears
// to be the stale pre-rename version; confirm it is deleted or updated.
VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
// CUDA backward entry point: gradient of the forward pass.
VISION_API at::Tensor roi_pool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
} // namespace ops
} // namespace vision
#include "deform_conv2d.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -77,6 +79,10 @@ at::Tensor deform_conv2d_autocast(
use_mask)
.to(input.scalar_type());
}
// Route deform_conv2d through its autocast wrapper under automatic mixed
// precision (only compiled for CUDA/HIP builds — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
}
#endif
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
......@@ -118,6 +124,13 @@ _deform_conv2d_backward(
use_mask);
}
// Declare the dispatcher schemas for deform_conv2d and its backward op in
// this file (per-file registration instead of a central TORCH_LIBRARY list).
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
}
namespace {
class DeformConv2dFunction
......@@ -365,5 +378,10 @@ deform_conv2d_backward_autograd(
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
// Register the autograd-aware wrappers so gradients are tracked when
// deform_conv2d is called on tensors that require grad.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/deform_conv2d_kernel.h"
#ifdef WITH_CUDA
#include "cuda/deform_conv2d_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/deform_conv2d_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
VISION_API at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
......@@ -46,9 +22,9 @@ at::Tensor deform_conv2d_autocast(
int64_t groups,
int64_t offset_groups,
bool use_mask);
#endif
// C++ Backward
VISION_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
......@@ -67,40 +43,5 @@ _deform_conv2d_backward(
int64_t offset_groups,
bool use_mask);
// Autograd Forward and Backward
at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);
} // namespace ops
} // namespace vision
#include "new_empty_tensor_op.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
......@@ -35,5 +37,9 @@ at::Tensor new_empty_tensor(
return NewEmptyTensorOp::apply(input, shape)[0];
}
// Register _new_empty_tensor_op by function pointer (schema is inferred
// from the C++ signature rather than spelled out as a string).
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("_new_empty_tensor_op", &new_empty_tensor);
}
} // namespace ops
} // namespace vision
#pragma once
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
at::Tensor new_empty_tensor(
VISION_API at::Tensor new_empty_tensor(
const at::Tensor& input,
const c10::List<int64_t>& shape);
......
#include "nms.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -29,7 +31,15 @@ at::Tensor nms_autocast(
at::autocast::cached_cast(at::kFloat, scores),
iou_threshold);
}
// Route nms through its autocast wrapper under automatic mixed precision
// (only compiled for CUDA/HIP builds — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("nms", nms_autocast);
}
#endif
// Declare the dispatcher schema for nms in this file.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/nms_kernel.h"
#ifdef WITH_CUDA
#include "cuda/nms_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/nms_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor nms_autocast(
VISION_API at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold);
#endif
} // namespace ops
} // namespace vision
#include "ps_roi_align.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -43,6 +45,10 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
// Route ps_roi_align through its autocast wrapper under automatic mixed
// precision (CUDA/HIP builds only — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_align", ps_roi_align_autocast);
}
#endif
at::Tensor _ps_roi_align_backward(
......@@ -75,6 +81,13 @@ at::Tensor _ps_roi_align_backward(
width);
}
// Declare the dispatcher schemas for ps_roi_align and its backward op in
// this file.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
m.def(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class PSROIAlignFunction
......@@ -222,5 +235,10 @@ at::Tensor ps_roi_align_backward_autograd(
width)[0];
}
// Register the autograd-aware wrappers for ps_roi_align so gradients are
// tracked on tensors that require grad.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/ps_roi_align_kernel.h"
#ifdef WITH_CUDA
#include "cuda/ps_roi_align_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/ps_roi_align_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);
#endif
// C++ Backward
at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
// Autograd Forward and Backward
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);
at::Tensor ps_roi_align_backward_autograd(
VISION_API at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
......
#include "ps_roi_pool.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -39,6 +41,10 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
// Route ps_roi_pool through its autocast wrapper under automatic mixed
// precision (CUDA/HIP builds only — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_pool", ps_roi_pool_autocast);
}
#endif
at::Tensor _ps_roi_pool_backward(
......@@ -69,6 +75,13 @@ at::Tensor _ps_roi_pool_backward(
width);
}
// Declare the dispatcher schemas for ps_roi_pool and its backward op in
// this file.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
......@@ -201,5 +214,10 @@ at::Tensor ps_roi_pool_backward_autograd(
width)[0];
}
// Register the autograd-aware wrappers for ps_roi_pool so gradients are
// tracked on tensors that require grad.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/ps_roi_pool_kernel.h"
#ifdef WITH_CUDA
#include "cuda/ps_roi_pool_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/ps_roi_pool_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
#endif
// C++ Backward
at::Tensor _ps_roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
// Autograd Forward and Backward
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
at::Tensor ps_roi_pool_backward_autograd(
VISION_API at::Tensor _ps_roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
......
#include "roi_align.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -52,6 +54,10 @@ at::Tensor roi_align_autocast(
aligned)
.to(input.scalar_type());
}
// Route roi_align through its autocast wrapper under automatic mixed
// precision (CUDA/HIP builds only — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", roi_align_autocast);
}
#endif
at::Tensor _roi_align_backward(
......@@ -84,6 +90,13 @@ at::Tensor _roi_align_backward(
aligned);
}
// Declare the dispatcher schemas for roi_align and its backward op in
// this file.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
}
namespace {
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
......@@ -231,5 +244,10 @@ at::Tensor roi_align_backward_autograd(
aligned)[0];
}
// Register the autograd-aware wrappers for roi_align so gradients are
// tracked on tensors that require grad.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/roi_align_kernel.h"
#ifdef WITH_CUDA
#include "cuda/roi_align_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/roi_align_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
at::Tensor roi_align(
VISION_API at::Tensor roi_align(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
......@@ -22,43 +16,8 @@ at::Tensor roi_align(
int64_t sampling_ratio,
bool aligned);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor roi_align_autocast(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned);
#endif
// C++ Backward
at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned);
// Autograd Forward and Backward
at::Tensor roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned);
at::Tensor roi_align_backward_autograd(
VISION_API at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
......
#include "roi_pool.h"
#include <torch/extension.h>
#include <torch/autograd.h>
#include <torch/types.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
......@@ -39,6 +41,10 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
// Route roi_pool through its autocast wrapper under automatic mixed
// precision (CUDA/HIP builds only — see the enclosing #if).
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_pool", roi_pool_autocast);
}
#endif
at::Tensor _roi_pool_backward(
......@@ -68,6 +74,13 @@ at::Tensor _roi_pool_backward(
width);
}
// Declare the dispatcher schemas for roi_pool and its backward op in
// this file.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
......@@ -200,5 +213,10 @@ at::Tensor roi_pool_backward_autograd(
width)[0];
}
// Register the autograd-aware wrappers for roi_pool so gradients are
// tracked on tensors that require grad.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
#pragma once
#include "cpu/roi_pool_kernel.h"
#ifdef WITH_CUDA
#include "cuda/roi_pool_kernel.h"
#endif
#ifdef WITH_HIP
#include "hip/roi_pool_kernel.h"
#endif
#include <ATen/ATen.h>
#include "macros.h"
namespace vision {
namespace ops {
// C++ Forward
std::tuple<at::Tensor, at::Tensor> roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
// Autocast Forward
#if defined(WITH_CUDA) || defined(WITH_HIP)
std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
#endif
// C++ Backward
at::Tensor _roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
// Autograd Forward and Backward
std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
at::Tensor roi_pool_backward_autograd(
VISION_API at::Tensor _roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
......
#include "vision.h"
#include <Python.h>
#include <torch/script.h>
#include <torch/library.h>
#ifdef WITH_CUDA
#include <cuda.h>
......@@ -10,14 +10,6 @@
#include <hip/hip_runtime.h>
#endif
#include "deform_conv2d.h"
#include "new_empty_tensor_op.h"
#include "nms.h"
#include "ps_roi_align.h"
#include "ps_roi_pool.h"
#include "roi_align.h"
#include "roi_pool.h"
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
#ifdef _WIN32
......@@ -35,88 +27,8 @@ int64_t cuda_version() {
return -1;
#endif
}
} // namespace vision
using namespace vision::ops;
TORCH_LIBRARY(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
m.def(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
m.def(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def("_cuda_version", &vision::cuda_version);
m.def("_new_empty_tensor_op", &new_empty_tensor);
}
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
m.impl("nms", nms_cpu);
m.impl("ps_roi_align", ps_roi_align_forward_cpu);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu);
m.impl("ps_roi_pool", ps_roi_pool_forward_cpu);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu);
m.impl("roi_align", roi_align_forward_cpu);
m.impl("_roi_align_backward", roi_align_backward_cpu);
m.impl("roi_pool", roi_pool_forward_cpu);
m.impl("_roi_pool_backward", roi_pool_backward_cpu);
}
// TODO: Place this in a hypothetical separate torchvision_cuda library
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
m.impl("nms", nms_cuda);
m.impl("ps_roi_align", ps_roi_align_forward_cuda);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda);
m.impl("ps_roi_pool", ps_roi_pool_forward_cuda);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda);
m.impl("roi_align", roi_align_forward_cuda);
m.impl("_roi_align_backward", roi_align_backward_cuda);
m.impl("roi_pool", roi_pool_forward_cuda);
m.impl("_roi_pool_backward", roi_pool_backward_cuda);
}
#endif
// Autocast only needs to wrap forward pass ops.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
m.impl("nms", nms_autocast);
m.impl("ps_roi_align", ps_roi_align_autocast);
m.impl("ps_roi_pool", ps_roi_pool_autocast);
m.impl("roi_align", roi_align_autocast);
m.impl("roi_pool", roi_pool_autocast);
}
#endif
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
// After moving all op schemas into their own files, vision.cpp only keeps
// the _cuda_version utility registration (schema inferred from the C++
// signature).
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("_cuda_version", &cuda_version);
}
} // namespace vision
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment