Unverified Commit 3711754a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add Torch Selective macros in all C++ Ops for better support on mobile (#3218)

* Adding TORCH_SELECTIVE_* macros on op registration.

* Adding torchvision namespace.
parent 4d2d8bb0
...@@ -45,7 +45,9 @@ at::Tensor deform_conv2d_autocast( ...@@ -45,7 +45,9 @@ at::Tensor deform_conv2d_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast); m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -22,7 +22,7 @@ at::Tensor nms_autocast( ...@@ -22,7 +22,7 @@ at::Tensor nms_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("nms", nms_autocast); m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -32,7 +32,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast( ...@@ -32,7 +32,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_align", ps_roi_align_autocast); m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast( ...@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_pool", ps_roi_pool_autocast); m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -31,7 +31,9 @@ at::Tensor roi_align_autocast( ...@@ -31,7 +31,9 @@ at::Tensor roi_align_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", roi_align_autocast); m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast( ...@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_pool", roi_pool_autocast); m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autocast));
} }
} // namespace ops } // namespace ops
......
...@@ -254,8 +254,12 @@ deform_conv2d_backward_autograd( ...@@ -254,8 +254,12 @@ deform_conv2d_backward_autograd(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd); m.impl(
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_autograd));
} }
} // namespace ops } // namespace ops
......
...@@ -154,8 +154,12 @@ at::Tensor ps_roi_align_backward_autograd( ...@@ -154,8 +154,12 @@ at::Tensor ps_roi_align_backward_autograd(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd); m.impl(
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_autograd));
} }
} // namespace ops } // namespace ops
......
...@@ -139,8 +139,12 @@ at::Tensor ps_roi_pool_backward_autograd( ...@@ -139,8 +139,12 @@ at::Tensor ps_roi_pool_backward_autograd(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd); m.impl(
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_autograd));
} }
} // namespace ops } // namespace ops
......
...@@ -154,8 +154,12 @@ at::Tensor roi_align_backward_autograd( ...@@ -154,8 +154,12 @@ at::Tensor roi_align_backward_autograd(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd); m.impl(
m.impl("_roi_align_backward", roi_align_backward_autograd); TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_autograd));
} }
} // namespace ops } // namespace ops
......
...@@ -139,8 +139,12 @@ at::Tensor roi_pool_backward_autograd( ...@@ -139,8 +139,12 @@ at::Tensor roi_pool_backward_autograd(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd); m.impl(
m.impl("_roi_pool_backward", roi_pool_backward_autograd); TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_autograd));
} }
} // namespace ops } // namespace ops
......
...@@ -1143,8 +1143,12 @@ deform_conv2d_backward_kernel( ...@@ -1143,8 +1143,12 @@ deform_conv2d_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel); m.impl(
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -109,7 +109,7 @@ at::Tensor nms_kernel( ...@@ -109,7 +109,7 @@ at::Tensor nms_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", nms_kernel); m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -422,8 +422,12 @@ at::Tensor ps_roi_align_backward_kernel( ...@@ -422,8 +422,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel); m.impl(
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -261,8 +261,12 @@ at::Tensor ps_roi_pool_backward_kernel( ...@@ -261,8 +261,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel); m.impl(
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -500,8 +500,12 @@ at::Tensor roi_align_backward_kernel( ...@@ -500,8 +500,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", roi_align_forward_kernel); m.impl(
m.impl("_roi_align_backward", roi_align_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -237,8 +237,12 @@ at::Tensor roi_pool_backward_kernel( ...@@ -237,8 +237,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_pool", roi_pool_forward_kernel); m.impl(
m.impl("_roi_pool_backward", roi_pool_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -1189,8 +1189,12 @@ deform_conv2d_backward_kernel( ...@@ -1189,8 +1189,12 @@ deform_conv2d_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel); m.impl(
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -168,7 +168,7 @@ at::Tensor nms_kernel( ...@@ -168,7 +168,7 @@ at::Tensor nms_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("nms", nms_kernel); m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
} }
} // namespace ops } // namespace ops
......
...@@ -440,8 +440,12 @@ at::Tensor ps_roi_align_backward_kernel( ...@@ -440,8 +440,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace } // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel); m.impl(
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel); TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
} }
} // namespace ops } // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment