Unverified Commit c359d8d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Deprecate the C++ vision::models namespace (#4375)

* Add deprecation warnings on vision::models

* Change the C++ example.

* Change readme.

* Update deprecation warning.
parent b1ae9a23
...@@ -106,9 +106,9 @@ otherwise, add the include and library paths in the environment variables ``TORC ...@@ -106,9 +106,9 @@ otherwise, add the include and library paths in the environment variables ``TORC
.. _libjpeg: http://ijg.org/ .. _libjpeg: http://ijg.org/
.. _libjpeg-turbo: https://libjpeg-turbo.org/ .. _libjpeg-turbo: https://libjpeg-turbo.org/
C++ API Using the models on C++
======= =======================
TorchVision also offers a C++ API that contains C++ equivalent of python models. TorchVision provides an example project for how to use the models on C++ using JIT Script.
Installation From source: Installation From source:
......
#include <iostream> #include <iostream>
#include <torch/script.h>
#include <torch/torch.h> #include <torch/torch.h>
#include <torchvision/vision.h> #include <torchvision/vision.h>
#include <torchvision/models/resnet.h>
int main() int main() {
{ torch::DeviceType device_type;
auto model = vision::models::ResNet18(); device_type = torch::kCPU;
model->eval();
// Create a random input tensor and run it through the model. torch::jit::script::Module model;
auto in = torch::rand({1, 3, 10, 10}); try {
auto out = model->forward(in); std::cout << "Loading model\n";
// Deserialize the ScriptModule from a file using torch::jit::load().
model = torch::jit::load("resnet18.pt");
std::cout << "Model loaded\n";
} catch (const torch::Error& e) {
std::cout << "error loading the model\n";
return -1;
} catch (const std::exception& e) {
std::cout << "Other error: " << e.what() << "\n";
return -1;
}
std::cout << out.sizes(); // TorchScript models require a List[IValue] as input
std::vector<torch::jit::IValue> inputs;
// Create a random input tensor and run it through the model.
inputs.push_back(torch::rand({1, 3, 10, 10}));
auto out = model.forward(inputs);
std::cout << out << "\n";
if (torch::cuda::is_available()) { if (torch::cuda::is_available()) {
// Move model and inputs to GPU // Move model and inputs to GPU
model->to(torch::kCUDA); model.to(torch::kCUDA);
auto gpu_in = in.to(torch::kCUDA);
auto gpu_out = model->forward(gpu_in); // Add GPU inputs
inputs.clear();
torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
inputs.push_back(torch::rand({1, 3, 10, 10}, options));
std::cout << gpu_out.sizes(); auto gpu_out = model.forward(inputs);
std::cout << gpu_out << "\n";
} }
} }
import os.path as osp
import torch
import torchvision

# Directory of this script and the repository assets directory two levels up.
# NOTE(review): neither is used below — presumably kept for parity with other
# example scripts; confirm before removing.
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))

# Build an untrained ResNet-18 and switch it to inference mode so the
# serialized module does not carry training-only behavior (dropout, BN updates).
model = torchvision.models.resnet18(pretrained=False)
model.eval()

# torch.jit.script compiles the model (scripting, not tracing), so the name
# reflects that. The resulting "resnet18.pt" is what the C++ example loads
# via torch::jit::load().
scripted_model = torch.jit.script(model)
scripted_model.save("resnet18.pt")
...@@ -98,13 +98,18 @@ fi ...@@ -98,13 +98,18 @@ fi
# Compile and run the CPP example # Compile and run the CPP example
popd popd
cd examples/cpp/hello_world cd examples/cpp/hello_world
mkdir build mkdir build
# Trace model
python trace_model.py
cp resnet18.pt build
cd build cd build
cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch
if [[ "$OSTYPE" == "msys" ]]; then if [[ "$OSTYPE" == "msys" ]]; then
"$script_dir/windows/internal/vc_env_helper.bat" "$script_dir/windows/internal/build_cpp_example.bat" $PARALLELISM "$script_dir/windows/internal/vc_env_helper.bat" "$script_dir/windows/internal/build_cpp_example.bat" $PARALLELISM
mv resnet18.pt Release
cd Release cd Release
else else
make -j$PARALLELISM make -j$PARALLELISM
......
...@@ -32,6 +32,8 @@ AlexNetImpl::AlexNetImpl(int64_t num_classes) { ...@@ -32,6 +32,8 @@ AlexNetImpl::AlexNetImpl(int64_t num_classes) {
register_module("features", features); register_module("features", features);
register_module("classifier", classifier); register_module("classifier", classifier);
modelsimpl::deprecation_warning();
} }
torch::Tensor AlexNetImpl::forward(torch::Tensor x) { torch::Tensor AlexNetImpl::forward(torch::Tensor x) {
......
...@@ -142,6 +142,8 @@ DenseNetImpl::DenseNetImpl( ...@@ -142,6 +142,8 @@ DenseNetImpl::DenseNetImpl(
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
torch::nn::init::constant_(M->bias, 0); torch::nn::init::constant_(M->bias, 0);
} }
modelsimpl::deprecation_warning();
} }
torch::Tensor DenseNetImpl::forward(torch::Tensor x) { torch::Tensor DenseNetImpl::forward(torch::Tensor x) {
......
#include "googlenet.h" #include "googlenet.h"
#include "modelsimpl.h"
namespace vision { namespace vision {
namespace models { namespace models {
...@@ -143,6 +145,8 @@ GoogLeNetImpl::GoogLeNetImpl( ...@@ -143,6 +145,8 @@ GoogLeNetImpl::GoogLeNetImpl(
if (init_weights) if (init_weights)
_initialize_weights(); _initialize_weights();
modelsimpl::deprecation_warning();
} }
void GoogLeNetImpl::_initialize_weights() { void GoogLeNetImpl::_initialize_weights() {
......
#include "inception.h" #include "inception.h"
#include "modelsimpl.h"
namespace vision { namespace vision {
namespace models { namespace models {
...@@ -297,6 +299,8 @@ InceptionV3Impl::InceptionV3Impl( ...@@ -297,6 +299,8 @@ InceptionV3Impl::InceptionV3Impl(
register_module("Mixed_7b", Mixed_7b); register_module("Mixed_7b", Mixed_7b);
register_module("Mixed_7c", Mixed_7c); register_module("Mixed_7c", Mixed_7c);
register_module("fc", fc); register_module("fc", fc);
modelsimpl::deprecation_warning();
} }
InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) { InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) {
......
...@@ -158,6 +158,8 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) { ...@@ -158,6 +158,8 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
register_module("classifier", classifier); register_module("classifier", classifier);
_initialize_weights(); _initialize_weights();
modelsimpl::deprecation_warning();
} }
torch::Tensor MNASNetImpl::forward(torch::Tensor x) { torch::Tensor MNASNetImpl::forward(torch::Tensor x) {
......
...@@ -146,6 +146,8 @@ MobileNetV2Impl::MobileNetV2Impl( ...@@ -146,6 +146,8 @@ MobileNetV2Impl::MobileNetV2Impl(
torch::nn::init::zeros_(M->bias); torch::nn::init::zeros_(M->bias);
} }
} }
modelsimpl::deprecation_warning();
} }
torch::Tensor MobileNetV2Impl::forward(at::Tensor x) { torch::Tensor MobileNetV2Impl::forward(at::Tensor x) {
......
...@@ -34,6 +34,13 @@ inline bool double_compare(double a, double b) { ...@@ -34,6 +34,13 @@ inline bool double_compare(double a, double b) {
return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon(); return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
}; };
// Emits a one-time deprecation notice for the hand-written vision::models
// C++ implementations, pointing users at TorchScript export instead.
// TORCH_WARN_ONCE prints the message at most once per process, so every
// model constructor can call this helper without spamming the log.
inline void deprecation_warning() {
TORCH_WARN_ONCE(
"The vision::models namespace is not actively maintained, use at "
"your own discretion. We recommend using Torch Script instead: "
"https://pytorch.org/tutorials/advanced/cpp_export.html");
}
} // namespace modelsimpl } // namespace modelsimpl
} // namespace models } // namespace models
} // namespace vision } // namespace vision
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <torch/nn.h> #include <torch/nn.h>
#include "../macros.h" #include "../macros.h"
#include "modelsimpl.h"
namespace vision { namespace vision {
namespace models { namespace models {
...@@ -164,6 +165,8 @@ ResNetImpl<Block>::ResNetImpl( ...@@ -164,6 +165,8 @@ ResNetImpl<Block>::ResNetImpl(
else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get()))
torch::nn::init::constant_(M->bn2->weight, 0); torch::nn::init::constant_(M->bn2->weight, 0);
} }
modelsimpl::deprecation_warning();
} }
template <typename Block> template <typename Block>
......
...@@ -146,6 +146,8 @@ ShuffleNetV2Impl::ShuffleNetV2Impl( ...@@ -146,6 +146,8 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
register_module("stage4", stage4); register_module("stage4", stage4);
register_module("conv2", conv5); register_module("conv2", conv5);
register_module("fc", fc); register_module("fc", fc);
modelsimpl::deprecation_warning();
} }
torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) { torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) {
......
...@@ -93,6 +93,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes) ...@@ -93,6 +93,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
if (M->options.bias()) if (M->options.bias())
torch::nn::init::constant_(M->bias, 0); torch::nn::init::constant_(M->bias, 0);
} }
modelsimpl::deprecation_warning();
} }
torch::Tensor SqueezeNetImpl::forward(torch::Tensor x) { torch::Tensor SqueezeNetImpl::forward(torch::Tensor x) {
......
...@@ -69,6 +69,8 @@ VGGImpl::VGGImpl( ...@@ -69,6 +69,8 @@ VGGImpl::VGGImpl(
if (initialize_weights) if (initialize_weights)
_initialize_weights(); _initialize_weights();
modelsimpl::deprecation_warning();
} }
torch::Tensor VGGImpl::forward(torch::Tensor x) { torch::Tensor VGGImpl::forward(torch::Tensor x) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment