Unverified Commit 3711754a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add Torch Selective macros in all C++ Ops for better support on mobile (#3218)

* Adding TORCH_SELECTIVE_* macros on op registration.

* Adding torchvision namespace.
parent 4d2d8bb0
...@@ -276,8 +276,12 @@ at::Tensor ps_roi_pool_backward_kernel( ...@@ -276,8 +276,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel); m.impl(
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -449,8 +449,12 @@ at::Tensor roi_align_backward_kernel( ...@@ -449,8 +449,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", roi_align_forward_kernel); m.impl(
m.impl("_roi_align_backward", roi_align_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -260,8 +260,12 @@ at::Tensor roi_pool_backward_kernel( ...@@ -260,8 +260,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_pool", roi_pool_forward_kernel); m.impl(
m.impl("_roi_pool_backward", roi_pool_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -84,10 +84,10 @@ _deform_conv2d_backward( ...@@ -84,10 +84,10 @@ _deform_conv2d_backward(
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"); "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
} }
} // namespace ops } // namespace ops
......
...@@ -16,7 +16,8 @@ at::Tensor nms( ...@@ -16,7 +16,8 @@ at::Tensor nms(
} }
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -54,10 +54,10 @@ at::Tensor _ps_roi_align_backward( ...@@ -54,10 +54,10 @@ at::Tensor _ps_roi_align_backward(
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -50,10 +50,10 @@ at::Tensor _ps_roi_pool_backward( ...@@ -50,10 +50,10 @@ at::Tensor _ps_roi_pool_backward(
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -64,10 +64,10 @@ at::Tensor _roi_align_backward( ...@@ -64,10 +64,10 @@ at::Tensor _roi_align_backward(
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"); "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -49,10 +49,10 @@ at::Tensor _roi_pool_backward( ...@@ -49,10 +49,10 @@ at::Tensor _roi_pool_backward(
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def( m.def(TORCH_SELECTIVE_SCHEMA(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment