Unverified Commit 3711754a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add Torch Selective macros in all C++ Ops for better support on mobile (#3218)

* Adding TORCH_SELECTIVE_* macros on op registration.

* Adding torchvision namespace.
parent 4d2d8bb0
......@@ -45,7 +45,9 @@ at::Tensor deform_conv2d_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autocast));
}
} // namespace ops
......
......@@ -22,7 +22,7 @@ at::Tensor nms_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("nms", nms_autocast);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
}
} // namespace ops
......
......@@ -32,7 +32,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_align", ps_roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autocast));
}
} // namespace ops
......
......@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_pool", ps_roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autocast));
}
} // namespace ops
......
......@@ -31,7 +31,9 @@ at::Tensor roi_align_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autocast));
}
} // namespace ops
......
......@@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_pool", roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autocast));
}
} // namespace ops
......
......@@ -254,8 +254,12 @@ deform_conv2d_backward_autograd(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_autograd));
}
} // namespace ops
......
......@@ -154,8 +154,12 @@ at::Tensor ps_roi_align_backward_autograd(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_autograd));
}
} // namespace ops
......
......@@ -139,8 +139,12 @@ at::Tensor ps_roi_pool_backward_autograd(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_autograd));
}
} // namespace ops
......
......@@ -154,8 +154,12 @@ at::Tensor roi_align_backward_autograd(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_autograd));
}
} // namespace ops
......
......@@ -139,8 +139,12 @@ at::Tensor roi_pool_backward_autograd(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_autograd));
}
} // namespace ops
......
......@@ -1143,8 +1143,12 @@ deform_conv2d_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}
} // namespace ops
......
......@@ -109,7 +109,7 @@ at::Tensor nms_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
......
......@@ -422,8 +422,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}
} // namespace ops
......
......@@ -261,8 +261,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
}
} // namespace ops
......
......@@ -500,8 +500,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
}
} // namespace ops
......
......@@ -237,8 +237,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}
} // namespace ops
......
......@@ -1189,8 +1189,12 @@ deform_conv2d_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}
} // namespace ops
......
......@@ -168,7 +168,7 @@ at::Tensor nms_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
......
......@@ -440,8 +440,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}
} // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment