Unverified Commit adfc15c4 authored by bmanga's avatar bmanga Committed by GitHub
Browse files

Ensure torchvision operators are added in C++ (#2798)

* Ensure torchvision operators are registered in C++ via weak symbols

* Add note to README on how to ensure that torchvision operators are available in C++

* Fix dllimport/dllexport on windows, format files

* Factor out common macros in single file

* Expose cuda_version in the API, use it to avoid pruning of ops initializer
parent 5e33cc87
...@@ -123,6 +123,11 @@ so make sure that it is also available to cmake via the ``CMAKE_PREFIX_PATH``. ...@@ -123,6 +123,11 @@ so make sure that it is also available to cmake via the ``CMAKE_PREFIX_PATH``.
For an example setup, take a look at ``examples/cpp/hello_world``. For an example setup, take a look at ``examples/cpp/hello_world``.
TorchVision Operators
---------------------
In order to get the torchvision operators registered with torch (eg. for the JIT), all you need to do is to ensure that you
:code:`#include <torchvision/vision.h>` in your project.
Documentation Documentation
============= =============
You can find the API documentation on the pytorch website: https://pytorch.org/docs/stable/torchvision/index.html You can find the API documentation on the pytorch website: https://pytorch.org/docs/stable/torchvision/index.html
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/extension.h>
#include "../macros.h"
#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu( VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input, const at::Tensor& input,
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/extension.h>
#include "../macros.h"
#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
VISION_API at::Tensor ROIAlign_forward_cuda( VISION_API at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input, const at::Tensor& input,
......
#ifndef TORCHVISION_MACROS_H
#define TORCHVISION_MACROS_H
#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define VISION_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define VISION_INLINE_VARIABLE __declspec(selectany)
#else
#define VISION_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
#endif // TORCHVISION_MACROS_H
...@@ -34,13 +34,15 @@ PyMODINIT_FUNC PyInit__C(void) { ...@@ -34,13 +34,15 @@ PyMODINIT_FUNC PyInit__C(void) {
#endif #endif
#endif #endif
int64_t _cuda_version() { namespace vision {
int64_t cuda_version() noexcept {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return CUDA_VERSION; return CUDA_VERSION;
#else #else
return -1; return -1;
#endif #endif
} }
} // namespace vision
TORCH_LIBRARY(torchvision, m) { TORCH_LIBRARY(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
...@@ -53,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) { ...@@ -53,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) {
m.def("ps_roi_align", &ps_roi_align); m.def("ps_roi_align", &ps_roi_align);
m.def("ps_roi_pool", &ps_roi_pool); m.def("ps_roi_pool", &ps_roi_pool);
m.def("deform_conv2d", &deform_conv2d); m.def("deform_conv2d", &deform_conv2d);
m.def("_cuda_version", &_cuda_version); m.def("_cuda_version", &vision::cuda_version);
} }
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
......
...@@ -2,5 +2,18 @@ ...@@ -2,5 +2,18 @@
#define VISION_H #define VISION_H
#include <torchvision/models/models.h> #include <torchvision/models/models.h>
#include <cstdint>
#include "macros.h"
namespace vision {
VISION_API int64_t cuda_version() noexcept;
namespace detail {
// Dummy variable to reference a symbol from vision.cpp.
// This ensures that the torchvision library and the ops registration
// initializers are not pruned.
VISION_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace vision
#endif // VISION_H #endif // VISION_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment