"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2de0516442da20433efc7d2ed04bea1b429e0dee"
Unverified Commit 21f70c17 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Remove cpp model in v0.14 due to deprecation (#6632)

* Remove cpp models

* Also remove the whole models folder in csrc

* Cleanup test for cpp model
parent c4c28dff
setup.py (unified diff, de-duplicated from the side-by-side scrape):

@@ -140,8 +140,6 @@ def get_extensions():
     )
     print("Compiling extensions with following flags:")
-    compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1"
-    print(f"  WITH_CPP_MODELS_TEST: {compile_cpp_tests}")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f"  FORCE_CUDA: {force_cuda}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
@@ -189,18 +187,6 @@ def get_extensions():
     sources = main_file + source_cpu
     extension = CppExtension
-    if compile_cpp_tests:
-        print("Compiling CPP tests")
-        test_dir = os.path.join(this_dir, "test")
-        models_dir = os.path.join(this_dir, "torchvision", "csrc", "models")
-        test_file = glob.glob(os.path.join(test_dir, "*.cpp"))
-        source_models = glob.glob(os.path.join(models_dir, "*.cpp"))
-        test_file = [os.path.join(test_dir, s) for s in test_file]
-        source_models = [os.path.join(models_dir, s) for s in source_models]
-        tests = test_file + source_models
-        tests_include_dirs = [test_dir, models_dir]
     define_macros = []
     extra_compile_args = {"cxx": []}
@@ -247,16 +233,6 @@ def get_extensions():
             extra_compile_args=extra_compile_args,
         )
     ]
-    if compile_cpp_tests:
-        ext_modules.append(
-            extension(
-                "torchvision._C_tests",
-                tests,
-                include_dirs=tests_include_dirs,
-                define_macros=define_macros,
-                extra_compile_args=extra_compile_args,
-            )
-        )
     # ------------------- Torchvision extra extensions ------------------------
     vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
...
import os
import sys
import unittest
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import models
try:
from torchvision import _C_tests
except ImportError:
_C_tests = None
def process_model(model, tensor, func, name):
    """Trace ``model`` to TorchScript and compare Python vs. C++ outputs.

    The traced module is saved to ``model.pt`` in the working directory so
    that ``func`` (one of the ``_C_tests.forward_*`` wrappers) can load it and
    run a forward pass on ``tensor``. Raises ``AssertionError`` on mismatch.
    """
    model.eval()
    traced = torch.jit.trace(model, tensor)
    traced.save("model.pt")
    expected = model.forward(tensor)
    actual = func("model.pt", tensor)
    assert torch.allclose(expected, actual), f"Output mismatch of {name} models"
def read_image1():
    """Return the Grace Hopper test image as a 1 x 3 x 224 x 224 float tensor."""
    assets_dir = os.path.dirname(os.path.abspath(__file__))
    image_path = os.path.join(assets_dir, "assets", "encode_jpeg", "grace_hopper_517x606.jpg")
    resized = Image.open(image_path).resize((224, 224))
    tensor = F.convert_image_dtype(F.pil_to_tensor(resized))
    return tensor.view(1, 3, 224, 224)
def read_image2():
    """Return a batch of two identical 3 x 299 x 299 Grace Hopper images.

    Inception v3 expects 299 x 299 inputs; the batch dimension is 2 so the
    model's batch-norm statistics see more than one sample.
    """
    assets_dir = os.path.dirname(os.path.abspath(__file__))
    image_path = os.path.join(assets_dir, "assets", "encode_jpeg", "grace_hopper_517x606.jpg")
    resized = Image.open(image_path).resize((299, 299))
    tensor = F.convert_image_dtype(F.pil_to_tensor(resized)).view(1, 3, 299, 299)
    return torch.cat([tensor, tensor], 0)
# NOTE(review): the `or True` makes this skip unconditional on EVERY platform,
# not just macOS — the C++ models are deprecated and broken on main, so the
# whole suite is intentionally disabled.
@unittest.skipIf(
    sys.platform == "darwin" or True,
    "C++ models are broken on OS X at the moment, and there's a BC breakage on main; "
    "see https://github.com/pytorch/vision/issues/1191",
)
class Tester(unittest.TestCase):
    """Compare Python model outputs against the `torchvision._C_tests` C++ forwards.

    Each test traces a torchvision model to TorchScript, saves it to disk, and
    checks that the matching C++ `forward_*` wrapper loads it and produces the
    same output (via `process_model`).
    """

    # Shared 1 x 3 x 224 x 224 input. NOTE: this runs at class-definition time
    # (it needs the asset image on disk) even though every test is skipped.
    image = read_image1()

    def test_alexnet(self):
        process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet")

    def test_vgg11(self):
        process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, "VGG11")

    def test_vgg13(self):
        process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13")

    def test_vgg16(self):
        process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16")

    def test_vgg19(self):
        process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19")

    def test_vgg11_bn(self):
        process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN")

    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN")

    def test_vgg16_bn(self):
        process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN")

    def test_vgg19_bn(self):
        process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN")

    def test_resnet18(self):
        process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18")

    def test_resnet34(self):
        process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34")

    def test_resnet50(self):
        process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50")

    def test_resnet101(self):
        process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101")

    def test_resnet152(self):
        process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152")

    def test_resnext50_32x4d(self):
        process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")

    def test_resnext101_32x8d(self):
        process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d")

    def test_wide_resnet50_2(self):
        process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, "WideResNet50_2")

    def test_wide_resnet101_2(self):
        process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")

    def test_squeezenet1_0(self):
        process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0")

    def test_squeezenet1_1(self):
        process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1")

    def test_densenet121(self):
        process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121")

    def test_densenet169(self):
        process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169")

    def test_densenet201(self):
        process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201")

    def test_densenet161(self):
        process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161")

    def test_mobilenet_v2(self):
        process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet")

    def test_googlenet(self):
        process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet")

    def test_mnasnet0_5(self):
        process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")

    def test_mnasnet0_75(self):
        process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")

    def test_mnasnet1_0(self):
        process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")

    def test_mnasnet1_3(self):
        process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")

    def test_inception_v3(self):
        # Inception v3 needs a 299x299 batch; rebinding self.image here creates
        # an instance attribute that shadows the class-level 224x224 image.
        self.image = read_image2()
        process_model(models.inception_v3(), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")
if __name__ == "__main__":
unittest.main()
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include "../torchvision/csrc/models/models.h"
using namespace vision::models;
// Loads a serialized `Model` from `input_path`, switches it to eval mode,
// and runs a single forward pass on `x`.
template <typename Model>
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
  Model network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x);
}
// One thin, non-template wrapper per architecture so that each model's
// forward can be exposed to Python by name via pybind11 below.
torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
  return forward_model<AlexNet>(input_path, x);
}

torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11>(input_path, x);
}

torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13>(input_path, x);
}

torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16>(input_path, x);
}

torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19>(input_path, x);
}

torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11BN>(input_path, x);
}

torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13BN>(input_path, x);
}

torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16BN>(input_path, x);
}

torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19BN>(input_path, x);
}

torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet18>(input_path, x);
}

torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet34>(input_path, x);
}

torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet50>(input_path, x);
}

torch::Tensor forward_resnet101(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet101>(input_path, x);
}

torch::Tensor forward_resnet152(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet152>(input_path, x);
}

torch::Tensor forward_resnext50_32x4d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext50_32x4d>(input_path, x);
}

torch::Tensor forward_resnext101_32x8d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext101_32x8d>(input_path, x);
}

torch::Tensor forward_wide_resnet50_2(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<WideResNet50_2>(input_path, x);
}

torch::Tensor forward_wide_resnet101_2(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<WideResNet101_2>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_0>(input_path, x);
}

torch::Tensor forward_squeezenet1_1(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_1>(input_path, x);
}

torch::Tensor forward_densenet121(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet121>(input_path, x);
}

torch::Tensor forward_densenet169(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet169>(input_path, x);
}

torch::Tensor forward_densenet201(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet201>(input_path, x);
}

torch::Tensor forward_densenet161(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet161>(input_path, x);
}

torch::Tensor forward_mobilenetv2(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<MobileNetV2>(input_path, x);
}

// GoogLeNet and Inception v3 cannot go through forward_model<>: their
// forward returns a struct (main output plus aux logits), so only the main
// `.output` tensor is returned for comparison against Python.
torch::Tensor forward_googlenet(
    const std::string& input_path,
    torch::Tensor x) {
  GoogLeNet network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}

torch::Tensor forward_inceptionv3(
    const std::string& input_path,
    torch::Tensor x) {
  InceptionV3 network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}

torch::Tensor forward_mnasnet0_5(const std::string& input_path, torch::Tensor x) {
  return forward_model<MNASNet0_5>(input_path, x);
}

torch::Tensor forward_mnasnet0_75(const std::string& input_path, torch::Tensor x) {
  return forward_model<MNASNet0_75>(input_path, x);
}

torch::Tensor forward_mnasnet1_0(const std::string& input_path, torch::Tensor x) {
  return forward_model<MNASNet1_0>(input_path, x);
}

torch::Tensor forward_mnasnet1_3(const std::string& input_path, torch::Tensor x) {
  return forward_model<MNASNet1_3>(input_path, x);
}
// Binds every forward_* wrapper under the same name into the Python
// extension module (built as `torchvision._C_tests` by setup.py).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");
  m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
  m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
  m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
  m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");
  m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
  m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
  m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
  m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");
  m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
  m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
  m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
  m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
  m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
  m.def(
      "forward_resnext50_32x4d",
      &forward_resnext50_32x4d,
      "forward_resnext50_32x4d");
  m.def(
      "forward_resnext101_32x8d",
      &forward_resnext101_32x8d,
      "forward_resnext101_32x8d");
  m.def(
      "forward_wide_resnet50_2",
      &forward_wide_resnet50_2,
      "forward_wide_resnet50_2");
  m.def(
      "forward_wide_resnet101_2",
      &forward_wide_resnet101_2,
      "forward_wide_resnet101_2");
  m.def(
      "forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
  m.def(
      "forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");
  m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
  m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
  m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
  m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");
  m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");
  m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
  m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
  m.def("forward_mnasnet0_5", &forward_mnasnet0_5, "forward_mnasnet0_5");
  m.def("forward_mnasnet0_75", &forward_mnasnet0_75, "forward_mnasnet0_75");
  m.def("forward_mnasnet1_0", &forward_mnasnet1_0, "forward_mnasnet1_0");
  m.def("forward_mnasnet1_3", &forward_mnasnet1_3, "forward_mnasnet1_3");
}
#include "alexnet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
// Builds the AlexNet layers. NOTE: the "features"/"classifier" submodule
// names are part of the serialized state_dict layout — do not rename them.
AlexNetImpl::AlexNetImpl(int64_t num_classes) {
  // Convolutional trunk: 5 conv layers with ReLU, 3 max-pools.
  features = torch::nn::Sequential(
      torch::nn::Conv2d(
          torch::nn::Conv2dOptions(3, 64, 11).stride(4).padding(2)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 192, 5).padding(2)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(192, 384, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(384, 256, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2));

  // Fully-connected head over the pooled 256 x 6 x 6 feature map.
  classifier = torch::nn::Sequential(
      torch::nn::Dropout(),
      torch::nn::Linear(256 * 6 * 6, 4096),
      torch::nn::Functional(torch::relu),
      torch::nn::Dropout(),
      torch::nn::Linear(4096, 4096),
      torch::nn::Functional(torch::relu),
      torch::nn::Linear(4096, num_classes));

  register_module("features", features);
  register_module("classifier", classifier);

  // The whole C++ models API is deprecated; warn at construction time.
  modelsimpl::deprecation_warning();
}
// Convolutional trunk, fixed 6x6 adaptive pooling, flatten, FC classifier.
torch::Tensor AlexNetImpl::forward(torch::Tensor x) {
  auto out = features->forward(x);
  out = torch::adaptive_avg_pool2d(out, {6, 6});
  out = out.view({out.size(0), -1});
  return classifier->forward(out);
}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
// AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
struct VISION_API AlexNetImpl : torch::nn::Module {
  // Convolutional trunk and fully-connected head; built in the constructor.
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  explicit AlexNetImpl(int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

// Generates the `AlexNet` module holder (shared-ownership wrapper).
TORCH_MODULE(AlexNet);
} // namespace models
} // namespace vision
#include "densenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// One dense layer: BN -> ReLU -> 1x1 bottleneck conv -> BN -> ReLU -> 3x3
// conv, with the input concatenated onto the new features on output.
struct _DenseLayerImpl : torch::nn::SequentialImpl {
  // Dropout probability applied to the newly computed features.
  double drop_rate;

  _DenseLayerImpl(
      int64_t num_input_features,
      int64_t growth_rate,
      int64_t bn_size,
      double drop_rate)
      : drop_rate(drop_rate) {
    push_back("norm1", torch::nn::BatchNorm2d(num_input_features));
    push_back("relu1", torch::nn::Functional(modelsimpl::relu_));
    // 1x1 bottleneck reduces channels to bn_size * growth_rate.
    push_back(
        "conv1",
        torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
                              .stride(1)
                              .bias(false)));
    push_back("norm2", torch::nn::BatchNorm2d(bn_size * growth_rate));
    push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
    push_back(
        "conv2",
        torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3)
                              .stride(1)
                              .padding(1)
                              .bias(false)));
  }

  torch::Tensor forward(torch::Tensor x) {
    auto new_features = torch::nn::SequentialImpl::forward(x);
    if (drop_rate > 0)
      new_features =
          torch::dropout(new_features, drop_rate, this->is_training());
    // Dense connectivity: output carries the input plus the new features.
    return torch::cat({x, new_features}, 1);
  }
};

TORCH_MODULE(_DenseLayer);
// A block of `num_layers` dense layers; layer i receives the original input
// channels plus i * growth_rate concatenated feature channels.
struct _DenseBlockImpl : torch::nn::SequentialImpl {
  _DenseBlockImpl(
      int64_t num_layers,
      int64_t num_input_features,
      int64_t bn_size,
      int64_t growth_rate,
      double drop_rate) {
    for (int64_t i = 0; i < num_layers; ++i) {
      auto layer = _DenseLayer(
          num_input_features + i * growth_rate,
          growth_rate,
          bn_size,
          drop_rate);
      // 1-based names ("denselayer1", ...) to match the Python state_dict.
      push_back("denselayer" + std::to_string(i + 1), layer);
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(_DenseBlock);
// Transition between dense blocks: BN -> ReLU -> 1x1 conv (channel
// reduction) -> 2x2 average pool (spatial downsampling).
struct _TransitionImpl : torch::nn::SequentialImpl {
  _TransitionImpl(int64_t num_input_features, int64_t num_output_features) {
    push_back("norm", torch::nn::BatchNorm2d(num_input_features));
    // NOTE(review): the submodule name "relu " has a trailing space; kept
    // as-is because renaming would change serialized state_dict keys.
    push_back("relu ", torch::nn::Functional(modelsimpl::relu_));
    push_back(
        "conv",
        torch::nn::Conv2d(Options(num_input_features, num_output_features, 1)
                              .stride(1)
                              .bias(false)));
    push_back("pool", torch::nn::Functional([](const torch::Tensor& input) {
                return torch::avg_pool2d(input, 2, 2, 0, false, true);
              }));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(_Transition);
// Assembles the DenseNet backbone: stem conv, alternating dense blocks and
// transitions per `block_config`, final BN, and a linear classifier.
DenseNetImpl::DenseNetImpl(
    int64_t num_classes,
    int64_t growth_rate,
    const std::vector<int64_t>& block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate) {
  // First convolution
  features = torch::nn::Sequential();
  features->push_back(
      "conv0",
      torch::nn::Conv2d(
          Options(3, num_init_features, 7).stride(2).padding(3).bias(false)));
  features->push_back("norm0", torch::nn::BatchNorm2d(num_init_features));
  features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
  features->push_back(
      "pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false));

  // Each denseblock
  auto num_features = num_init_features;
  for (size_t i = 0; i < block_config.size(); ++i) {
    auto num_layers = block_config[i];
    _DenseBlock block(
        num_layers, num_features, bn_size, growth_rate, drop_rate);
    features->push_back("denseblock" + std::to_string(i + 1), block);
    // Each dense layer appends growth_rate channels.
    num_features = num_features + num_layers * growth_rate;
    // Transitions halve the channels between blocks (not after the last).
    if (i != block_config.size() - 1) {
      auto trans = _Transition(num_features, num_features / 2);
      features->push_back("transition" + std::to_string(i + 1), trans);
      num_features = num_features / 2;
    }
  }

  // Final batch norm
  features->push_back("norm5", torch::nn::BatchNorm2d(num_features));

  // Linear layer
  classifier = torch::nn::Linear(num_features, num_classes);

  register_module("features", features);
  register_module("classifier", classifier);

  // Official init from torch repo.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(M->weight);
    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::constant_(M->weight, 1);
      torch::nn::init::constant_(M->bias, 0);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
      torch::nn::init::constant_(M->bias, 0);
  }

  // The whole C++ models API is deprecated; warn at construction time.
  modelsimpl::deprecation_warning();
}
// Backbone features, final ReLU, global 1x1 average pool, then classifier.
torch::Tensor DenseNetImpl::forward(torch::Tensor x) {
  auto feats = this->features->forward(x);
  auto pooled = torch::adaptive_avg_pool2d(torch::relu_(feats), {1, 1});
  auto flat = pooled.view({feats.size(0), -1});
  return this->classifier->forward(flat);
}
// The DenseNet-121/169/201/161 variants only differ in their default
// hyper-parameters (declared in the header); each constructor simply
// forwards its arguments to the shared DenseNetImpl constructor.
DenseNet121Impl::DenseNet121Impl(
    int64_t num_classes,
    int64_t growth_rate,
    const std::vector<int64_t>& block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet169Impl::DenseNet169Impl(
    int64_t num_classes,
    int64_t growth_rate,
    const std::vector<int64_t>& block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet201Impl::DenseNet201Impl(
    int64_t num_classes,
    int64_t growth_rate,
    const std::vector<int64_t>& block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet161Impl::DenseNet161Impl(
    int64_t num_classes,
    int64_t growth_rate,
    const std::vector<int64_t>& block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
// Densenet-BC model class, based on
// "Densely Connected Convolutional Networks"
// <https://arxiv.org/pdf/1608.06993.pdf>
// Args:
// num_classes (int) - number of classification classes
// growth_rate (int) - how many filters to add each layer (`k` in paper)
// block_config (list of 4 ints) - how many layers in each pooling block
// num_init_features (int) - the number of filters to learn in the first
// convolution layer
// bn_size (int) - multiplicative factor for number of bottle neck layers
// (i.e. bn_size * k features in the bottleneck layer)
// drop_rate (float) - dropout rate after each dense layer
struct VISION_API DenseNetImpl : torch::nn::Module {
  // Backbone (stem + dense blocks + transitions + final BN) and linear head.
  torch::nn::Sequential features{nullptr};
  torch::nn::Linear classifier{nullptr};

  explicit DenseNetImpl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      const std::vector<int64_t>& block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);

  torch::Tensor forward(torch::Tensor x);
};

// The named variants below differ from the base only in their default
// growth_rate / block_config / num_init_features values.
struct VISION_API DenseNet121Impl : DenseNetImpl {
  explicit DenseNet121Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      const std::vector<int64_t>& block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct VISION_API DenseNet169Impl : DenseNetImpl {
  explicit DenseNet169Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      const std::vector<int64_t>& block_config = {6, 12, 32, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct VISION_API DenseNet201Impl : DenseNetImpl {
  explicit DenseNet201Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      const std::vector<int64_t>& block_config = {6, 12, 48, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct VISION_API DenseNet161Impl : DenseNetImpl {
  explicit DenseNet161Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 48,
      const std::vector<int64_t>& block_config = {6, 12, 36, 24},
      int64_t num_init_features = 96,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// Module holders (shared-ownership wrappers) for each implementation.
TORCH_MODULE(DenseNet);
TORCH_MODULE(DenseNet121);
TORCH_MODULE(DenseNet169);
TORCH_MODULE(DenseNet201);
TORCH_MODULE(DenseNet161);
} // namespace models
} // namespace vision
#include "googlenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
namespace _googlenetimpl {
// Conv2d followed by BatchNorm (eps = 0.001); the conv bias is disabled
// because batch norm's own bias makes it redundant.
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
  options.bias(false);
  conv = torch::nn::Conv2d(options);
  bn = torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));

  register_module("conv", conv);
  register_module("bn", bn);
}
// conv -> batch norm -> in-place ReLU.
torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) {
  auto out = bn->forward(conv->forward(x));
  return out.relu_();
}
// Builds the four parallel inception branches: 1x1, 1x1 -> 3x3,
// 1x1 -> 3x3 (the "5x5" branch uses a 3x3 conv here, matching the Python
// GoogLeNet implementation), and max-pool -> 1x1 projection.
InceptionImpl::InceptionImpl(
    int64_t in_channels,
    int64_t ch1x1,
    int64_t ch3x3red,
    int64_t ch3x3,
    int64_t ch5x5red,
    int64_t ch5x5,
    int64_t pool_proj) {
  branch1 = BasicConv2d(Options(in_channels, ch1x1, 1));

  branch2->push_back(BasicConv2d(Options(in_channels, ch3x3red, 1)));
  branch2->push_back(BasicConv2d(Options(ch3x3red, ch3x3, 3).padding(1)));

  branch3->push_back(BasicConv2d(Options(in_channels, ch5x5red, 1)));
  branch3->push_back(BasicConv2d(Options(ch5x5red, ch5x5, 3).padding(1)));

  branch4->push_back(
      torch::nn::Functional(torch::max_pool2d, 3, 1, 1, 1, true));
  branch4->push_back(BasicConv2d(Options(in_channels, pool_proj, 1)));

  register_module("branch1", branch1);
  register_module("branch2", branch2);
  register_module("branch3", branch3);
  register_module("branch4", branch4);
}
// Runs the four parallel branches and concatenates along the channel dim.
torch::Tensor InceptionImpl::forward(torch::Tensor x) {
  return torch::cat(
      {branch1->forward(x),
       branch2->forward(x),
       branch3->forward(x),
       branch4->forward(x)},
      1);
}
// Auxiliary classifier head: 1x1 conv to 128 channels, then two FC layers
// (fc1 assumes a 128 x 4 x 4 input, i.e. 2048 flattened features).
InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes) {
  conv = BasicConv2d(Options(in_channels, 128, 1));
  fc1 = torch::nn::Linear(2048, 1024);
  fc2 = torch::nn::Linear(1024, num_classes);

  register_module("conv", conv);
  register_module("fc1", fc1);
  register_module("fc2", fc2);
}
torch::Tensor InceptionAuxImpl::forward(at::Tensor x) {
  // aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
  x = torch::adaptive_avg_pool2d(x, {4, 4});
  // aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
  x = conv->forward(x);
  // N x 128 x 4 x 4
  x = x.view({x.size(0), -1});
  // N x 2048
  x = fc1->forward(x).relu_();
  // N x 1024 (fc1 maps 2048 -> 1024)
  x = torch::dropout(x, 0.7, is_training());
  // N x 1024
  x = fc2->forward(x);
  // N x num_classes (fc2 maps 1024 -> num_classes)
  return x;
}
} // namespace _googlenetimpl
// Builds the GoogLeNet stem, the nine inception blocks, the (optional)
// auxiliary classifiers, and the final dropout + FC head. Registered
// submodule names mirror the Python model's state_dict layout.
GoogLeNetImpl::GoogLeNetImpl(
    int64_t num_classes,
    bool aux_logits,
    bool transform_input,
    bool init_weights) {
  this->aux_logits = aux_logits;
  this->transform_input = transform_input;

  conv1 = _googlenetimpl::BasicConv2d(Options(3, 64, 7).stride(2).padding(3));
  conv2 = _googlenetimpl::BasicConv2d(Options(64, 64, 1));
  conv3 = _googlenetimpl::BasicConv2d(Options(64, 192, 3).padding(1));

  inception3a = _googlenetimpl::Inception(192, 64, 96, 128, 16, 32, 32);
  inception3b = _googlenetimpl::Inception(256, 128, 128, 192, 32, 96, 64);

  inception4a = _googlenetimpl::Inception(480, 192, 96, 208, 16, 48, 64);
  inception4b = _googlenetimpl::Inception(512, 160, 112, 224, 24, 64, 64);
  inception4c = _googlenetimpl::Inception(512, 128, 128, 256, 24, 64, 64);
  inception4d = _googlenetimpl::Inception(512, 112, 144, 288, 32, 64, 64);
  inception4e = _googlenetimpl::Inception(528, 256, 160, 320, 32, 128, 128);

  inception5a = _googlenetimpl::Inception(832, 256, 160, 320, 32, 128, 128);
  inception5b = _googlenetimpl::Inception(832, 384, 192, 384, 48, 128, 128);

  // Auxiliary heads are only constructed (and registered) when requested.
  if (aux_logits) {
    aux1 = _googlenetimpl::InceptionAux(512, num_classes);
    aux2 = _googlenetimpl::InceptionAux(528, num_classes);

    register_module("aux1", aux1);
    register_module("aux2", aux2);
  }

  dropout = torch::nn::Dropout(0.2);
  fc = torch::nn::Linear(1024, num_classes);

  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("conv3", conv3);

  register_module("inception3a", inception3a);
  register_module("inception3b", inception3b);

  register_module("inception4a", inception4a);
  register_module("inception4b", inception4b);
  register_module("inception4c", inception4c);
  register_module("inception4d", inception4d);
  register_module("inception4e", inception4e);

  register_module("inception5a", inception5a);
  register_module("inception5b", inception5b);

  register_module("dropout", dropout);
  register_module("fc", fc);

  if (init_weights)
    _initialize_weights();

  // The whole C++ models API is deprecated; warn at construction time.
  modelsimpl::deprecation_warning();
}
// Initializes conv/linear weights with a plain normal distribution (the
// Python model uses truncated normal, unavailable here) and batch norm
// with ones/zeros.
void GoogLeNetImpl::_initialize_weights() {
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::normal_(M->weight); // Note: used instead of truncated
                                           // normal initialization
    else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
      torch::nn::init::normal_(M->weight); // Note: used instead of truncated
                                           // normal initialization
    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    }
  }
}
// Full forward pass. Returns the main logits plus the two auxiliary logits;
// aux1/aux2 are only computed in training mode with aux_logits enabled and
// are otherwise returned as default-constructed (undefined) tensors.
GoogLeNetOutput GoogLeNetImpl::forward(torch::Tensor x) {
  if (transform_input) {
    // Re-normalize from ImageNet statistics to the [-1, 1]-style range the
    // original GoogLeNet weights expect, one channel at a time.
    auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) +
        (0.485 - 0.5) / 0.5;
    auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) +
        (0.456 - 0.5) / 0.5;
    auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) +
        (0.406 - 0.5) / 0.5;

    x = torch::cat({x_ch0, x_ch1, x_ch2}, 1);
  }

  // N x 3 x 224 x 224
  x = conv1->forward(x);
  // N x 64 x 112 x 112
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);
  // N x 64 x 56 x 56
  x = conv2->forward(x);
  // N x 64 x 56 x 56
  x = conv3->forward(x);
  // N x 192 x 56 x 56
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);

  // N x 192 x 28 x 28
  x = inception3a->forward(x);
  // N x 256 x 28 x 28
  x = inception3b->forward(x);
  // N x 480 x 28 x 28
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);
  // N x 480 x 14 x 14
  x = inception4a->forward(x);
  // N x 512 x 14 x 14

  // First auxiliary head taps the inception4a output during training.
  torch::Tensor aux1;
  if (is_training() && aux_logits)
    aux1 = this->aux1->forward(x);

  x = inception4b->forward(x);
  // N x 512 x 14 x 14
  x = inception4c->forward(x);
  // N x 512 x 14 x 14
  x = inception4d->forward(x);
  // N x 528 x 14 x 14

  // Second auxiliary head taps the inception4d output during training.
  torch::Tensor aux2;
  if (is_training() && aux_logits)
    aux2 = this->aux2->forward(x);

  x = inception4e(x);
  // N x 832 x 14 x 14
  x = torch::max_pool2d(x, 2, 2, 0, 1, true);
  // N x 832 x 7 x 7
  x = inception5a(x);
  // N x 832 x 7 x 7
  x = inception5b(x);
  // N x 1024 x 7 x 7

  x = torch::adaptive_avg_pool2d(x, {1, 1});
  // N x 1024 x 1 x 1
  x = x.view({x.size(0), -1});
  // N x 1024
  x = dropout->forward(x);
  x = fc->forward(x);
  // N x 1000(num_classes)

  return {x, aux1, aux2};
}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
namespace _googlenetimpl {
// Conv2d (bias disabled) + BatchNorm + ReLU building block.
struct VISION_API BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm2d bn{nullptr};

  explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

// One inception block: four parallel branches concatenated on channels.
struct VISION_API InceptionImpl : torch::nn::Module {
  BasicConv2d branch1{nullptr};
  torch::nn::Sequential branch2, branch3, branch4;

  InceptionImpl(
      int64_t in_channels,
      int64_t ch1x1,
      int64_t ch3x3red,
      int64_t ch3x3,
      int64_t ch5x5red,
      int64_t ch5x5,
      int64_t pool_proj);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Inception);

// Auxiliary classifier head used during training.
struct VISION_API InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv{nullptr};
  torch::nn::Linear fc1{nullptr}, fc2{nullptr};

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionAux);

} // namespace _googlenetimpl

// forward() result: main logits plus the two auxiliary logits (the aux
// tensors are undefined outside training with aux_logits enabled).
struct VISION_API GoogLeNetOutput {
  torch::Tensor output;
  torch::Tensor aux1;
  torch::Tensor aux2;
};

struct VISION_API GoogLeNetImpl : torch::nn::Module {
  bool aux_logits, transform_input;

  _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};

  _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr},
      inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr},
      inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr},
      inception5b{nullptr};

  _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr};

  torch::nn::Dropout dropout{nullptr};
  torch::nn::Linear fc{nullptr};

  explicit GoogLeNetImpl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false,
      bool init_weights = true);

  void _initialize_weights();

  GoogLeNetOutput forward(torch::Tensor x);
};

TORCH_MODULE(GoogLeNet);
} // namespace models
} // namespace vision
#include "inception.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
namespace _inceptionimpl {
// Builds the convolution (bias disabled — the following BatchNorm supplies
// the shift) and its batch-norm (eps 0.001), then initializes weights:
// conv ~ N(0, std_dev) in place of the reference truncated-normal init,
// BN scale = 1 and shift = 0.
BasicConv2dImpl::BasicConv2dImpl(
    torch::nn::Conv2dOptions options,
    double std_dev) {
  options.bias(false);
  conv = torch::nn::Conv2d(options);
  bn = torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
  register_module("conv", conv);
  register_module("bn", bn);
  torch::nn::init::normal_(
      conv->weight,
      0,
      std_dev); // Note: used instead of truncated normal initialization
  torch::nn::init::constant_(bn->weight, 1);
  torch::nn::init::constant_(bn->bias, 0);
}
// conv -> batch-norm -> in-place ReLU.
torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) {
  return torch::relu_(bn->forward(conv->forward(x)));
}
// Inception-A block: 1x1, 5x5 (via 1x1 reduce), double-3x3 (via 1x1
// reduce), and pooled-projection branches.
InceptionAImpl::InceptionAImpl(int64_t in_channels, int64_t pool_features)
    : branch1x1(Options(in_channels, 64, 1)),
      branch5x5_1(Options(in_channels, 48, 1)),
      branch5x5_2(Options(48, 64, 5).padding(2)),
      branch3x3dbl_1(Options(in_channels, 64, 1)),
      branch3x3dbl_2(Options(64, 96, 3).padding(1)),
      branch3x3dbl_3(Options(96, 96, 3).padding(1)),
      branch_pool(Options(in_channels, pool_features, 1)) {
  register_module("branch1x1", branch1x1);
  register_module("branch5x5_1", branch5x5_1);
  register_module("branch5x5_2", branch5x5_2);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3", branch3x3dbl_3);
  register_module("branch_pool", branch_pool);
}
// Runs the four parallel branches and concatenates them on the channel
// dimension (dim 1).
torch::Tensor InceptionAImpl::forward(const torch::Tensor& x) {
  auto b1 = branch1x1->forward(x);
  auto b5 = branch5x5_2->forward(branch5x5_1->forward(x));
  auto b3dbl = branch3x3dbl_3->forward(
      branch3x3dbl_2->forward(branch3x3dbl_1->forward(x)));
  auto pooled = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b5, b3dbl, pooled}, 1);
}
// Inception-B (grid-size reduction) block: both conv branches use
// stride 2 to halve the spatial resolution.
InceptionBImpl::InceptionBImpl(int64_t in_channels)
    : branch3x3(Options(in_channels, 384, 3).stride(2)),
      branch3x3dbl_1(Options(in_channels, 64, 1)),
      branch3x3dbl_2(Options(64, 96, 3).padding(1)),
      branch3x3dbl_3(Options(96, 96, 3).stride(2)) {
  register_module("branch3x3", branch3x3);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3", branch3x3dbl_3);
}
// Downsampling block: two strided conv branches plus a stride-2 max-pool,
// concatenated on channels.
torch::Tensor InceptionBImpl::forward(const torch::Tensor& x) {
  auto b3 = branch3x3->forward(x);
  auto b3dbl = branch3x3dbl_3->forward(
      branch3x3dbl_2->forward(branch3x3dbl_1->forward(x)));
  auto pooled = torch::max_pool2d(x, 3, 2);
  return torch::cat({b3, b3dbl, pooled}, 1);
}
// Inception-C block: factorized 7x7 convolutions (1x7 followed by 7x1)
// with `channels_7x7` intermediate channels.
InceptionCImpl::InceptionCImpl(int64_t in_channels, int64_t channels_7x7) {
  branch1x1 = BasicConv2d(Options(in_channels, 192, 1));
  auto c7 = channels_7x7;
  branch7x7_1 = BasicConv2d(Options(in_channels, c7, 1));
  branch7x7_2 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3}));
  branch7x7_3 = BasicConv2d(Options(c7, 192, {7, 1}).padding({3, 0}));
  branch7x7dbl_1 = BasicConv2d(Options(in_channels, c7, 1));
  branch7x7dbl_2 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0}));
  branch7x7dbl_3 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3}));
  branch7x7dbl_4 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0}));
  branch7x7dbl_5 = BasicConv2d(Options(c7, 192, {1, 7}).padding({0, 3}));
  branch_pool = BasicConv2d(Options(in_channels, 192, 1));
  register_module("branch1x1", branch1x1);
  register_module("branch7x7_1", branch7x7_1);
  register_module("branch7x7_2", branch7x7_2);
  register_module("branch7x7_3", branch7x7_3);
  register_module("branch7x7dbl_1", branch7x7dbl_1);
  register_module("branch7x7dbl_2", branch7x7dbl_2);
  register_module("branch7x7dbl_3", branch7x7dbl_3);
  register_module("branch7x7dbl_4", branch7x7dbl_4);
  register_module("branch7x7dbl_5", branch7x7dbl_5);
  register_module("branch_pool", branch_pool);
}
// Four parallel branches (1x1, factorized 7x7, double factorized 7x7,
// pooled projection) concatenated on channels.
torch::Tensor InceptionCImpl::forward(const torch::Tensor& x) {
  auto b1 = branch1x1->forward(x);
  auto b7 =
      branch7x7_3->forward(branch7x7_2->forward(branch7x7_1->forward(x)));
  auto b7dbl = branch7x7dbl_1->forward(x);
  b7dbl = branch7x7dbl_2->forward(b7dbl);
  b7dbl = branch7x7dbl_3->forward(b7dbl);
  b7dbl = branch7x7dbl_4->forward(b7dbl);
  b7dbl = branch7x7dbl_5->forward(b7dbl);
  auto pooled = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b7, b7dbl, pooled}, 1);
}
// Inception-D (grid-size reduction) block: stride-2 3x3 tails on both
// conv branches halve the spatial resolution.
InceptionDImpl::InceptionDImpl(int64_t in_channels)
    : branch3x3_1(Options(in_channels, 192, 1)),
      branch3x3_2(Options(192, 320, 3).stride(2)),
      branch7x7x3_1(Options(in_channels, 192, 1)),
      branch7x7x3_2(Options(192, 192, {1, 7}).padding({0, 3})),
      branch7x7x3_3(Options(192, 192, {7, 1}).padding({3, 0})),
      branch7x7x3_4(Options(192, 192, 3).stride(2)) {
  register_module("branch3x3_1", branch3x3_1);
  register_module("branch3x3_2", branch3x3_2);
  register_module("branch7x7x3_1", branch7x7x3_1);
  register_module("branch7x7x3_2", branch7x7x3_2);
  register_module("branch7x7x3_3", branch7x7x3_3);
  register_module("branch7x7x3_4", branch7x7x3_4);
}
// Two strided conv branches plus a stride-2 max-pool, concatenated.
torch::Tensor InceptionDImpl::forward(const torch::Tensor& x) {
  auto b3 = branch3x3_2->forward(branch3x3_1->forward(x));
  auto b7 = branch7x7x3_1->forward(x);
  b7 = branch7x7x3_2->forward(b7);
  b7 = branch7x7x3_3->forward(b7);
  b7 = branch7x7x3_4->forward(b7);
  auto pooled = torch::max_pool2d(x, 3, 2);
  return torch::cat({b3, b7, pooled}, 1);
}
// Inception-E block: the 3x3 branches split into parallel 1x3 / 3x1
// halves whose outputs are concatenated (expanded filter bank).
InceptionEImpl::InceptionEImpl(int64_t in_channels)
    : branch1x1(Options(in_channels, 320, 1)),
      branch3x3_1(Options(in_channels, 384, 1)),
      branch3x3_2a(Options(384, 384, {1, 3}).padding({0, 1})),
      branch3x3_2b(Options(384, 384, {3, 1}).padding({1, 0})),
      branch3x3dbl_1(Options(in_channels, 448, 1)),
      branch3x3dbl_2(Options(448, 384, 3).padding(1)),
      branch3x3dbl_3a(Options(384, 384, {1, 3}).padding({0, 1})),
      branch3x3dbl_3b(Options(384, 384, {3, 1}).padding({1, 0})),
      branch_pool(Options(in_channels, 192, 1)) {
  register_module("branch1x1", branch1x1);
  register_module("branch3x3_1", branch3x3_1);
  register_module("branch3x3_2a", branch3x3_2a);
  register_module("branch3x3_2b", branch3x3_2b);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3a", branch3x3dbl_3a);
  register_module("branch3x3dbl_3b", branch3x3dbl_3b);
  register_module("branch_pool", branch_pool);
}
// Each 3x3 branch fans out into a 1x3 and a 3x1 half that are
// concatenated before the final four-way channel concat.
torch::Tensor InceptionEImpl::forward(const torch::Tensor& x) {
  auto b1 = branch1x1->forward(x);

  auto stem3 = branch3x3_1->forward(x);
  auto b3 = torch::cat(
      {branch3x3_2a->forward(stem3), branch3x3_2b->forward(stem3)}, 1);

  auto stem3dbl = branch3x3dbl_2->forward(branch3x3dbl_1->forward(x));
  auto b3dbl = torch::cat(
      {branch3x3dbl_3a->forward(stem3dbl), branch3x3dbl_3b->forward(stem3dbl)},
      1);

  auto pooled = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b3, b3dbl, pooled}, 1);
}
// Auxiliary classifier: 1x1 reduce, 5x5 conv (initialized with the
// smaller std 0.01), then a linear head whose weights ~ N(0, 0.001) —
// both in place of the reference truncated-normal init.
InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes)
    : conv0(BasicConv2d(Options(in_channels, 128, 1))),
      conv1(BasicConv2d(Options(128, 768, 5), 0.01)),
      fc(768, num_classes) {
  torch::nn::init::normal_(
      fc->weight,
      0,
      0.001); // Note: used instead of truncated normal initialization
  register_module("conv0", conv0);
  register_module("conv1", conv1);
  register_module("fc", fc);
}
// Pool the 17x17 feature map down, apply the two convs, flatten, and
// classify. Shape comments track a 299x299 InceptionV3 input.
torch::Tensor InceptionAuxImpl::forward(torch::Tensor x) {
  // N x 768 x 17 x 17
  x = torch::avg_pool2d(x, 5, 3); // -> N x 768 x 5 x 5
  x = conv0->forward(x); // -> N x 128 x 5 x 5
  x = conv1->forward(x); // -> N x 768 x 1 x 1
  x = torch::adaptive_avg_pool2d(x, {1, 1}); // -> N x 768 x 1 x 1
  x = x.view({x.size(0), -1}); // flatten -> N x 768
  return fc->forward(x); // -> N x num_classes
}
} // namespace _inceptionimpl
// Assembles the InceptionV3 network: stem convs, A/B/C/D/E inception
// stages, the optional auxiliary classifier on the Mixed_6e output, and
// the final linear head (weights ~ N(0, 0.1) in place of the reference
// truncated-normal init). Emits a one-time deprecation warning for the
// whole vision::models C++ namespace.
InceptionV3Impl::InceptionV3Impl(
    int64_t num_classes,
    bool aux_logits,
    bool transform_input)
    : aux_logits(aux_logits), transform_input(transform_input) {
  Conv2d_1a_3x3 = _inceptionimpl::BasicConv2d(Options(3, 32, 3).stride(2));
  Conv2d_2a_3x3 = _inceptionimpl::BasicConv2d(Options(32, 32, 3));
  Conv2d_2b_3x3 = _inceptionimpl::BasicConv2d(Options(32, 64, 3).padding(1));
  Conv2d_3b_1x1 = _inceptionimpl::BasicConv2d(Options(64, 80, 1));
  Conv2d_4a_3x3 = _inceptionimpl::BasicConv2d(Options(80, 192, 3));
  Mixed_5b = _inceptionimpl::InceptionA(192, 32);
  Mixed_5c = _inceptionimpl::InceptionA(256, 64);
  Mixed_5d = _inceptionimpl::InceptionA(288, 64);
  Mixed_6a = _inceptionimpl::InceptionB(288);
  Mixed_6b = _inceptionimpl::InceptionC(768, 128);
  Mixed_6c = _inceptionimpl::InceptionC(768, 160);
  Mixed_6d = _inceptionimpl::InceptionC(768, 160);
  Mixed_6e = _inceptionimpl::InceptionC(768, 192);
  // AuxLogits stays a nullptr holder when aux_logits is false.
  if (aux_logits)
    AuxLogits = _inceptionimpl::InceptionAux(768, num_classes);
  Mixed_7a = _inceptionimpl::InceptionD(768);
  Mixed_7b = _inceptionimpl::InceptionE(1280);
  Mixed_7c = _inceptionimpl::InceptionE(2048);
  fc = torch::nn::Linear(2048, num_classes);
  torch::nn::init::normal_(
      fc->weight,
      0,
      0.1); // Note: used instead of truncated normal initialization
  register_module("Conv2d_1a_3x3", Conv2d_1a_3x3);
  register_module("Conv2d_2a_3x3", Conv2d_2a_3x3);
  register_module("Conv2d_2b_3x3", Conv2d_2b_3x3);
  register_module("Conv2d_3b_1x1", Conv2d_3b_1x1);
  register_module("Conv2d_4a_3x3", Conv2d_4a_3x3);
  register_module("Mixed_5b", Mixed_5b);
  register_module("Mixed_5c", Mixed_5c);
  register_module("Mixed_5d", Mixed_5d);
  register_module("Mixed_6a", Mixed_6a);
  register_module("Mixed_6b", Mixed_6b);
  register_module("Mixed_6c", Mixed_6c);
  register_module("Mixed_6d", Mixed_6d);
  register_module("Mixed_6e", Mixed_6e);
  if (!AuxLogits.is_empty())
    register_module("AuxLogits", AuxLogits);
  register_module("Mixed_7a", Mixed_7a);
  register_module("Mixed_7b", Mixed_7b);
  register_module("Mixed_7c", Mixed_7c);
  register_module("fc", fc);
  modelsimpl::deprecation_warning();
}
// Full InceptionV3 forward pass. Shape comments assume a 299x299 input.
// Returns the main logits plus the auxiliary logits; aux is an undefined
// tensor unless training with aux_logits enabled.
InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) {
  // Undo the standard ImageNet normalization and re-normalize each
  // channel to the [-1, 1]-style range the reference weights expect.
  if (transform_input) {
    auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) +
        (0.485 - 0.5) / 0.5;
    auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) +
        (0.456 - 0.5) / 0.5;
    auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) +
        (0.406 - 0.5) / 0.5;
    x = torch::cat({x_ch0, x_ch1, x_ch2}, 1);
  }
  // N x 3 x 299 x 299
  x = Conv2d_1a_3x3->forward(x);
  // N x 32 x 149 x 149
  x = Conv2d_2a_3x3->forward(x);
  // N x 32 x 147 x 147
  x = Conv2d_2b_3x3->forward(x);
  // N x 64 x 147 x 147
  x = torch::max_pool2d(x, 3, 2);
  // N x 64 x 73 x 73
  x = Conv2d_3b_1x1->forward(x);
  // N x 80 x 73 x 73
  x = Conv2d_4a_3x3->forward(x);
  // N x 192 x 71 x 71
  x = torch::max_pool2d(x, 3, 2);
  // N x 192 x 35 x 35
  x = Mixed_5b->forward(x);
  // N x 256 x 35 x 35
  x = Mixed_5c->forward(x);
  // N x 288 x 35 x 35
  x = Mixed_5d->forward(x);
  // N x 288 x 35 x 35
  x = Mixed_6a->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6b->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6c->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6d->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6e->forward(x);
  // N x 768 x 17 x 17
  // Auxiliary head branches off the Mixed_6e output, training only.
  torch::Tensor aux;
  if (is_training() && aux_logits)
    aux = AuxLogits->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_7a->forward(x);
  // N x 1280 x 8 x 8
  x = Mixed_7b->forward(x);
  // N x 2048 x 8 x 8
  x = Mixed_7c->forward(x);
  // N x 2048 x 8 x 8
  x = torch::adaptive_avg_pool2d(x, {1, 1});
  // N x 2048 x 1 x 1
  x = torch::dropout(x, 0.5, is_training());
  // N x 2048 x 1 x 1
  x = x.view({x.size(0), -1});
  // N x 2048
  x = fc->forward(x);
  // N x 1000 (num_classes)
  if (is_training() && aux_logits)
    return {x, aux};
  // In eval mode the aux slot is left as a default (undefined) tensor.
  return {x, {}};
}
// NOTE: the _inceptionimpl namespace is closed above (at its matching brace);
// nothing further here belongs to it.
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
namespace _inceptionimpl {
// Conv2d + BatchNorm2d + in-place ReLU; the elementary InceptionV3 unit.
// std_dev controls the normal init of the conv weights (used instead of
// the reference truncated-normal init).
struct VISION_API BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm2d bn{nullptr};
  explicit BasicConv2dImpl(
      torch::nn::Conv2dOptions options,
      double std_dev = 0.1);
  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);
// Inception-A block: 1x1, 5x5, double-3x3, and pooled-projection branches
// concatenated on the channel dimension.
struct VISION_API InceptionAImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;
  InceptionAImpl(int64_t in_channels, int64_t pool_features);
  torch::Tensor forward(const torch::Tensor& x);
};
// Inception-B block: grid-size reduction via stride-2 conv branches plus
// a stride-2 max-pool.
struct VISION_API InceptionBImpl : torch::nn::Module {
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;
  explicit InceptionBImpl(int64_t in_channels);
  torch::Tensor forward(const torch::Tensor& x);
};
// Inception-C block: factorized 7x7 convolutions (1x7 / 7x1 pairs) with
// `channels_7x7` intermediate channels.
struct VISION_API InceptionCImpl : torch::nn::Module {
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};
  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);
  torch::Tensor forward(const torch::Tensor& x);
};
// Inception-D block: grid-size reduction with a 3x3 branch and a
// factorized 7x7-then-3x3 branch, both ending in stride-2 convs.
struct VISION_API InceptionDImpl : torch::nn::Module {
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;
  explicit InceptionDImpl(int64_t in_channels);
  torch::Tensor forward(const torch::Tensor& x);
};
// Inception-E block: expanded filter bank whose 3x3 branches split into
// parallel 1x3 / 3x1 halves.
struct VISION_API InceptionEImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;
  explicit InceptionEImpl(int64_t in_channels);
  torch::Tensor forward(const torch::Tensor& x);
};
// Auxiliary classifier head; attached to the Mixed_6e output and
// evaluated only while training with aux_logits enabled.
struct VISION_API InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;
  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);
  torch::Tensor forward(torch::Tensor x);
};

// Module holders (shared_ptr wrappers) for the impl classes above.
TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);
TORCH_MODULE(InceptionAux);
} // namespace _inceptionimpl
// Result of InceptionV3Impl::forward: main logits plus the auxiliary
// logits (aux is an undefined tensor outside training with aux_logits).
struct VISION_API InceptionV3Output {
  torch::Tensor output;
  torch::Tensor aux;
};
// Inception v3 model architecture from
// "Rethinking the Inception Architecture for Computer Vision"
// <http://arxiv.org/abs/1512.00567>.
struct VISION_API InceptionV3Impl : torch::nn::Module {
  // aux_logits: build and (during training) evaluate the auxiliary head.
  // transform_input: re-normalize input channels before the stem.
  bool aux_logits, transform_input;
  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};
  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};
  torch::nn::Linear fc{nullptr};
  // Remains a nullptr holder when aux_logits is false.
  _inceptionimpl::InceptionAux AuxLogits{nullptr};

  explicit InceptionV3Impl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);
} // namespace models
} // namespace vision
#include "mnasnet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// MNASNet inverted-residual block: 1x1 expansion -> depthwise kxk ->
// linear 1x1 projection, with an identity shortcut when the input and
// output shapes match.
struct MNASNetInvertedResidualImpl : torch::nn::Module {
  bool apply_residual; // add input back only when in == out and stride == 1
  torch::nn::Sequential layers;

  MNASNetInvertedResidualImpl(
      int64_t input,
      int64_t output,
      int64_t kernel,
      int64_t stride,
      double expansion_factor,
      double bn_momentum = 0.1) {
    TORCH_CHECK(stride == 1 || stride == 2);
    TORCH_CHECK(kernel == 3 || kernel == 5);
    auto mid = int64_t(input * expansion_factor);
    apply_residual = input == output && stride == 1;
    // Pointwise expansion.
    layers->push_back(torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
    layers->push_back(torch::nn::BatchNorm2d(
        torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
    // Fixed: was Functional(Functional(...)) — a redundant self-copy of the
    // module holder with no behavioral effect.
    layers->push_back(torch::nn::Functional(modelsimpl::relu_));
    // Depthwise convolution (groups == channels).
    // Fixed: was Conv2d(Conv2d(...)) — same redundant double construction.
    layers->push_back(torch::nn::Conv2d(Options(mid, mid, kernel)
                                            .padding(kernel / 2)
                                            .stride(stride)
                                            .groups(mid)
                                            .bias(false)));
    layers->push_back(torch::nn::BatchNorm2d(
        torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
    layers->push_back(torch::nn::Functional(modelsimpl::relu_));
    // Linear pointwise projection (no activation).
    layers->push_back(torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
    layers->push_back(torch::nn::BatchNorm2d(
        torch::nn::BatchNormOptions(output).momentum(bn_momentum)));
    register_module("layers", layers);
  }

  torch::Tensor forward(torch::Tensor x) {
    if (apply_residual)
      return layers->forward(x) + x;
    return layers->forward(x);
  }
};

TORCH_MODULE(MNASNetInvertedResidual);
// Sequential subclass with an explicit single-tensor forward(), so a
// stack of inverted residuals can itself be pushed into a Sequential.
// (Name keeps the historical "Sequentail" spelling; callers depend on it.)
struct StackSequentailImpl : torch::nn::SequentialImpl {
  using SequentialImpl::SequentialImpl;
  torch::Tensor forward(torch::Tensor x) {
    return SequentialImpl::forward(x);
  }
};

TORCH_MODULE(StackSequentail);
// Builds a stack of `repeats` inverted-residual blocks; the first block
// performs the input->output / stride transition, the remaining ones
// keep shape (stride 1, output->output).
StackSequentail stack(
    int64_t input,
    int64_t output,
    int64_t kernel,
    int64_t stride,
    double exp_factor,
    int64_t repeats,
    double bn_momentum) {
  TORCH_CHECK(repeats >= 1);
  StackSequentail blocks;
  blocks->push_back(MNASNetInvertedResidual(
      input, output, kernel, stride, exp_factor, bn_momentum));
  for (int64_t rep = 1; rep < repeats; ++rep)
    blocks->push_back(MNASNetInvertedResidual(
        output, output, kernel, 1, exp_factor, bn_momentum));
  return blocks;
}
// Rounds `val` to the nearest multiple of `divisor` (never below
// `divisor` itself), rounding up whenever plain rounding would fall
// below `round_up_bias` times the original value.
int64_t round_to_multiple_of(
    int64_t val,
    int64_t divisor,
    double round_up_bias = .9) {
  TORCH_CHECK(0.0 < round_up_bias && round_up_bias < 1.0);
  const int64_t rounded =
      std::max(divisor, (val + divisor / 2) / divisor * divisor);
  if (rounded >= round_up_bias * val)
    return rounded;
  return rounded + divisor;
}
// Scales every depth by `alpha` and snaps each result to a multiple of 8.
std::vector<int64_t> scale_depths(std::vector<int64_t> depths, double alpha) {
  std::vector<int64_t> scaled;
  scaled.reserve(depths.size());
  for (auto depth : depths)
    scaled.push_back(round_to_multiple_of(int64_t(depth * alpha), 8));
  return scaled;
}
// Reference weight initialization for every submodule:
// Conv2d -> Kaiming normal (fan-out, ReLU); BatchNorm2d -> weight 1,
// bias 0; Linear -> N(0, 0.01) weights and zero bias.
void MNASNetImpl::_initialize_weights() {
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(
          M->weight, 0, torch::kFanOut, torch::kReLU);
    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(M->weight, 0, 0.01);
      torch::nn::init::zeros_(M->bias);
    }
  }
}
// BatchNorm momentum used throughout MNASNet: the paper's 0.9997
// running-stats decay expressed in torch's momentum convention.
// Fixed: parenthesized — the previous `1 - 0.9997` expansion was an
// unparenthesized expression macro, a latent precedence bug.
#define BN_MOMENTUM (1 - 0.9997)

// Builds the MNASNet backbone (stem + six inverted-residual stacks +
// 1x1 head) and the dropout/linear classifier, scaled by `alpha`.
MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
  // Per-stack output depths, scaled by alpha and rounded to multiples of 8.
  auto depths = scale_depths({24, 40, 80, 96, 192, 320}, alpha);
  // Stem: strided 3x3, depthwise 3x3, then a 1x1 projection to 16 channels.
  layers->push_back(
      torch::nn::Conv2d(Options(3, 32, 3).padding(1).stride(2).bias(false)));
  layers->push_back(torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
  layers->push_back(torch::nn::Functional(modelsimpl::relu_));
  layers->push_back(torch::nn::Conv2d(
      Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
  layers->push_back(torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
  layers->push_back(torch::nn::Functional(modelsimpl::relu_));
  layers->push_back(
      torch::nn::Conv2d(Options(32, 16, 1).padding(0).stride(1).bias(false)));
  layers->push_back(torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));
  // Inverted-residual stacks: (in, out, kernel, stride, expansion, repeats).
  layers->push_back(stack(16, depths[0], 3, 2, 3, 3, BN_MOMENTUM));
  layers->push_back(stack(depths[0], depths[1], 5, 2, 3, 3, BN_MOMENTUM));
  layers->push_back(stack(depths[1], depths[2], 5, 2, 6, 3, BN_MOMENTUM));
  layers->push_back(stack(depths[2], depths[3], 3, 1, 6, 2, BN_MOMENTUM));
  layers->push_back(stack(depths[3], depths[4], 5, 2, 6, 4, BN_MOMENTUM));
  layers->push_back(stack(depths[4], depths[5], 3, 1, 6, 1, BN_MOMENTUM));
  // Head: 1x1 expansion to 1280 channels before pooling.
  layers->push_back(torch::nn::Conv2d(
      Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
  layers->push_back(torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
  layers->push_back(torch::nn::Functional(modelsimpl::relu_));
  classifier = torch::nn::Sequential(
      torch::nn::Dropout(dropout), torch::nn::Linear(1280, num_classes));
  register_module("layers", layers);
  register_module("classifier", classifier);
  _initialize_weights();
  modelsimpl::deprecation_warning();
}
// Backbone features -> global average pool over H,W -> classifier head.
torch::Tensor MNASNetImpl::forward(torch::Tensor x) {
  auto features = layers->forward(x);
  return classifier->forward(features.mean({2, 3}));
}
// Preset width multipliers (alpha = 0.5, 0.75, 1.0, 1.3) matching the
// Python torchvision mnasnet variants.
MNASNet0_5Impl::MNASNet0_5Impl(int64_t num_classes, double dropout)
    : MNASNetImpl(.5, num_classes, dropout) {}

MNASNet0_75Impl::MNASNet0_75Impl(int64_t num_classes, double dropout)
    : MNASNetImpl(.75, num_classes, dropout) {}

MNASNet1_0Impl::MNASNet1_0Impl(int64_t num_classes, double dropout)
    : MNASNetImpl(1, num_classes, dropout) {}

MNASNet1_3Impl::MNASNet1_3Impl(int64_t num_classes, double dropout)
    : MNASNetImpl(1.3, num_classes, dropout) {}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
// MNASNet backbone + classifier, with channel widths scaled by `alpha`.
struct VISION_API MNASNetImpl : torch::nn::Module {
  torch::nn::Sequential layers, classifier;
  // Applies the reference initialization to all submodules.
  void _initialize_weights();
  explicit MNASNetImpl(
      double alpha,
      int64_t num_classes = 1000,
      double dropout = .2);
  torch::Tensor forward(torch::Tensor x);
};
// Preset-width variants (alpha = 0.5 / 0.75 / 1.0 / 1.3).
struct MNASNet0_5Impl : MNASNetImpl {
  explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet0_75Impl : MNASNetImpl {
  explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_0Impl : MNASNetImpl {
  explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_3Impl : MNASNetImpl {
  explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
};

// Module holders.
TORCH_MODULE(MNASNet);
TORCH_MODULE(MNASNet0_5);
TORCH_MODULE(MNASNet0_75);
TORCH_MODULE(MNASNet1_0);
TORCH_MODULE(MNASNet1_3);
} // namespace models
} // namespace vision
#include "mobilenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// Rounds `value` to the nearest multiple of `divisor`, clamped below by
// `min_value` (defaulting to `divisor`), and bumps up one step whenever
// rounding would drop below 90% of the original value.
int64_t make_divisible(
    double value,
    int64_t divisor,
    c10::optional<int64_t> min_value = {}) {
  const int64_t floor_value = min_value.value_or(divisor);
  int64_t rounded = (int64_t(value + divisor / 2) / divisor) * divisor;
  rounded = std::max(floor_value, rounded);
  if (rounded < .9 * value)
    rounded += divisor;
  return rounded;
}
// Conv2d (no bias, "same"-style padding) + BatchNorm2d + ReLU6, the
// standard MobileNetV2 building block.
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
  ConvBNReLUImpl(
      int64_t in_planes,
      int64_t out_planes,
      int64_t kernel_size = 3,
      int64_t stride = 1,
      int64_t groups = 1) {
    // Keeps spatial size for odd kernels at stride 1.
    auto padding = (kernel_size - 1) / 2;
    push_back(torch::nn::Conv2d(Options(in_planes, out_planes, kernel_size)
                                    .stride(stride)
                                    .padding(padding)
                                    .groups(groups)
                                    .bias(false)));
    push_back(torch::nn::BatchNorm2d(out_planes));
    push_back(torch::nn::Functional(modelsimpl::relu6_));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(ConvBNReLU);
struct MobileNetInvertedResidualImpl : torch::nn::Module {
int64_t stride;
bool use_res_connect;
torch::nn::Sequential conv;
MobileNetInvertedResidualImpl(
int64_t input,
int64_t output,
int64_t stride,
double expand_ratio)
: stride(stride), use_res_connect(stride == 1 && input == output) {
auto double_compare = [](double a, double b) {
return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
};
TORCH_CHECK(stride == 1 || stride == 2);
auto hidden_dim = int64_t(std::round(input * expand_ratio));
if (!double_compare(expand_ratio, 1))
conv->push_back(ConvBNReLU(input, hidden_dim, 1));
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
conv->push_back(torch::nn::Conv2d(
Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
conv->push_back(torch::nn::BatchNorm2d(output));
register_module("conv", conv);
}
torch::Tensor forward(torch::Tensor x) {
if (use_res_connect)
return x + conv->forward(x);
return conv->forward(x);
}
};
TORCH_MODULE(MobileNetInvertedResidual);
// Builds MobileNetV2: stem conv, a sequence of inverted-residual blocks
// described by (t=expansion, c=channels, n=repeats, s=first-stride)
// rows, a 1x1 head, and a dropout+linear classifier. Channel counts are
// scaled by width_mult and rounded to multiples of round_nearest.
MobileNetV2Impl::MobileNetV2Impl(
    int64_t num_classes,
    double width_mult,
    std::vector<std::vector<int64_t>> inverted_residual_settings,
    int64_t round_nearest) {
  using Block = MobileNetInvertedResidual;
  int64_t input_channel = 32;
  int64_t last_channel = 1280;

  // Default architecture table from the MobileNetV2 paper.
  if (inverted_residual_settings.empty())
    inverted_residual_settings = {
        // t, c, n, s
        {1, 16, 1, 1},
        {6, 24, 2, 2},
        {6, 32, 3, 2},
        {6, 64, 4, 2},
        {6, 96, 3, 1},
        {6, 160, 3, 2},
        {6, 320, 1, 1},
    };
  TORCH_CHECK(
      inverted_residual_settings[0].size() == 4,
      "inverted_residual_settings should contain 4-element vectors");

  input_channel = make_divisible(input_channel * width_mult, round_nearest);
  // The head never shrinks below 1280 channels even for width_mult < 1.
  this->last_channel =
      make_divisible(last_channel * std::max(1.0, width_mult), round_nearest);
  features->push_back(ConvBNReLU(3, input_channel, 3, 2));

  // Expand each settings row into n blocks; only the first block of a row
  // uses the row's stride.
  for (auto setting : inverted_residual_settings) {
    auto output_channel =
        make_divisible(setting[1] * width_mult, round_nearest);
    for (int64_t i = 0; i < setting[2]; ++i) {
      auto stride = i == 0 ? setting[3] : 1;
      features->push_back(
          Block(input_channel, output_channel, stride, setting[0]));
      input_channel = output_channel;
    }
  }

  features->push_back(ConvBNReLU(input_channel, this->last_channel, 1));
  classifier->push_back(torch::nn::Dropout(0.2));
  classifier->push_back(torch::nn::Linear(this->last_channel, num_classes));
  register_module("features", features);
  register_module("classifier", classifier);

  // Reference init: Conv2d -> Kaiming normal (fan-out) + zero bias if any;
  // BatchNorm2d -> weight 1 / bias 0; Linear -> N(0, 0.01) + zero bias.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
      if (M->options.bias())
        torch::nn::init::zeros_(M->bias);
    } else if (
        auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(M->weight, 0, 0.01);
      torch::nn::init::zeros_(M->bias);
    }
  }
  modelsimpl::deprecation_warning();
}
// Feature extractor -> spatial mean pool over H,W -> classifier.
torch::Tensor MobileNetV2Impl::forward(at::Tensor x) {
  auto pooled = features->forward(x).mean({2, 3});
  return classifier->forward(pooled);
}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
// MobileNetV2 backbone + classifier. inverted_residual_settings rows are
// (t=expansion, c=channels, n=repeats, s=first-stride); an empty vector
// selects the paper's default table.
struct VISION_API MobileNetV2Impl : torch::nn::Module {
  int64_t last_channel; // head width after width_mult scaling
  torch::nn::Sequential features, classifier;

  explicit MobileNetV2Impl(
      int64_t num_classes = 1000,
      double width_mult = 1.0,
      std::vector<std::vector<int64_t>> inverted_residual_settings = {},
      int64_t round_nearest = 8);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(MobileNetV2);
} // namespace models
} // namespace vision
#pragma once
#include "alexnet.h"
#include "densenet.h"
#include "googlenet.h"
#include "inception.h"
#include "mnasnet.h"
#include "mobilenet.h"
#include "resnet.h"
#include "shufflenetv2.h"
#include "squeezenet.h"
#include "vgg.h"
#pragma once
#include <torch/nn.h>
namespace vision {
namespace models {
namespace modelsimpl {
// TODO: torch::relu_ and torch::adaptive_avg_pool2d wrapped directly in
// torch::nn::Functional don't work, so these thin wrappers are kept.
// In-place ReLU returning its (modified) argument.
inline torch::Tensor& relu_(const torch::Tensor& x) {
  return x.relu_();
}
// In-place ReLU6: clamps x to [0, 6] and returns it.
inline torch::Tensor& relu6_(const torch::Tensor& x) {
  return x.clamp_(0, 6);
}
// Thin forwarding wrapper (see TODO above on torch::nn::Functional).
inline torch::Tensor adaptive_avg_pool2d(
    const torch::Tensor& x,
    torch::ExpandingArray<2> output_size) {
  return torch::adaptive_avg_pool2d(x, output_size);
}
// Thin forwarding wrapper (see TODO above on torch::nn::Functional).
inline torch::Tensor max_pool2d(
    const torch::Tensor& x,
    torch::ExpandingArray<2> kernel_size,
    torch::ExpandingArray<2> stride) {
  return torch::max_pool2d(x, kernel_size, stride);
}
// Approximate floating-point equality: |a - b| below the double machine
// epsilon. The tolerance is absolute, so this is only meaningful for
// values of magnitude ~1 (how it is used in this file).
// Fixed: removed the stray ';' after the function body (an empty
// declaration) and the redundant double(...) cast around std::abs.
inline bool double_compare(double a, double b) {
  return std::abs(a - b) < std::numeric_limits<double>::epsilon();
}
// Emits the one-time deprecation warning for the whole vision::models
// C++ namespace (called from every model constructor).
inline void deprecation_warning() {
  TORCH_WARN_ONCE(
      "The vision::models namespace is deprecated since 0.12 and will be "
      "removed in 0.14. We recommend using Torch Script instead: "
      "https://pytorch.org/tutorials/advanced/cpp_export.html");
}
} // namespace modelsimpl
} // namespace models
} // namespace vision
#include "resnet.h"
namespace vision {
namespace models {
namespace _resnetimpl {
// 3x3 convolution with 1-pixel padding and no bias.
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride,
    int64_t groups) {
  auto options = torch::nn::Conv2dOptions(in, out, 3)
                     .padding(1)
                     .stride(stride)
                     .groups(groups)
                     .bias(false);
  return torch::nn::Conv2d(options);
}
// 1x1 (pointwise) convolution with no bias.
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) {
  return torch::nn::Conv2d(
      torch::nn::Conv2dOptions(in, out, 1).stride(stride).bias(false));
}
// Channel expansion of each block's output relative to `planes`:
// BasicBlock keeps the width, Bottleneck quadruples it.
int BasicBlock::expansion = 1;
int Bottleneck::expansion = 4;
// Two 3x3 conv + BN pairs with an optional downsample path for the
// identity shortcut. Only groups=1 / base_width=64 is supported.
BasicBlock::BasicBlock(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    const torch::nn::Sequential& downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  TORCH_CHECK(
      groups == 1 && base_width == 64,
      "BasicBlock only supports groups=1 and base_width=64");
  // Both conv1 and downsample layers downsample the input when stride != 1
  conv1 = conv3x3(inplanes, planes, stride);
  conv2 = conv3x3(planes, planes);
  bn1 = torch::nn::BatchNorm2d(planes);
  bn2 = torch::nn::BatchNorm2d(planes);
  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("bn1", bn1);
  register_module("bn2", bn2);
  // Register the shortcut path only when one was supplied.
  if (!downsample.is_empty())
    register_module("downsample", this->downsample);
}
// 1x1 reduce -> 3x3 (possibly grouped/strided) -> 1x1 expand, each with
// BN, plus an optional downsample path for the shortcut. groups and
// base_width implement the ResNeXt / wide-ResNet variants.
Bottleneck::Bottleneck(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    const torch::nn::Sequential& downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  auto width = int64_t(planes * (base_width / 64.)) * groups;
  // Both conv2 and downsample layers downsample the input when stride != 1
  conv1 = conv1x1(inplanes, width);
  conv2 = conv3x3(width, width, stride, groups);
  conv3 = conv1x1(width, planes * expansion);
  bn1 = torch::nn::BatchNorm2d(width);
  bn2 = torch::nn::BatchNorm2d(width);
  bn3 = torch::nn::BatchNorm2d(planes * expansion);
  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("conv3", conv3);
  register_module("bn1", bn1);
  register_module("bn2", bn2);
  register_module("bn3", bn3);
  // Register the shortcut path only when one was supplied.
  if (!downsample.is_empty())
    register_module("downsample", this->downsample);
}
// conv/BN/ReLU chain with the residual add before the final ReLU;
// the downsample path adapts the identity when shapes change.
torch::Tensor Bottleneck::forward(torch::Tensor X) {
  auto shortcut = downsample.is_empty() ? X : downsample->forward(X);
  auto h = bn1->forward(conv1->forward(X)).relu_();
  h = bn2->forward(conv2->forward(h)).relu_();
  h = bn3->forward(conv3->forward(h));
  h += shortcut;
  return h.relu_();
}
// Two conv/BN stages with the residual add before the final ReLU;
// the downsample path adapts the identity when shapes change.
torch::Tensor BasicBlock::forward(torch::Tensor x) {
  auto shortcut = downsample.is_empty() ? x : downsample->forward(x);
  auto h = bn1->forward(conv1->forward(x)).relu_();
  h = bn2->forward(conv2->forward(h));
  h += shortcut;
  return h.relu_();
}
} // namespace _resnetimpl
// Standard ResNet depth presets (layer counts per stage), plus the
// ResNeXt (groups=32) and wide-ResNet (doubled base width) variants.
ResNet18Impl::ResNet18Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({2, 2, 2, 2}, num_classes, zero_init_residual) {}

ResNet34Impl::ResNet34Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

ResNet50Impl::ResNet50Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

ResNet101Impl::ResNet101Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual) {}

ResNet152Impl::ResNet152Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 8, 36, 3}, num_classes, zero_init_residual) {}

ResNext50_32x4dImpl::ResNext50_32x4dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 32, 4) {}

ResNext101_32x8dImpl::ResNext101_32x8dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}

WideResNet50_2Impl::WideResNet50_2Impl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}

WideResNet101_2Impl::WideResNet101_2Impl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}
} // namespace models
} // namespace vision
#pragma once
#include <torch/nn.h>
#include "../macros.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
template <typename Block>
struct ResNetImpl;
namespace _resnetimpl {
// 3x3 convolution with padding (definitions live in resnet.cpp).
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride = 1,
    int64_t groups = 1);

// 1x1 (pointwise) convolution.
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1);
// Residual building block used by the shallower ResNets: two 3x3
// convolutions plus an identity (or projected) shortcut.
struct VISION_API BasicBlock : torch::nn::Module {
  // ResNetImpl reads `expansion` and the members while assembling stages.
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  // Stride of the block (also its overall spatial downsampling factor).
  int64_t stride;
  // Projection applied to the shortcut when shapes change; empty
  // (is_empty()) for a plain identity shortcut.
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};

  // Channel multiplier from `planes` to the block's output width
  // (value is defined in the .cpp).
  static int expansion;

  BasicBlock(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      const torch::nn::Sequential& downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor x);
};
// Residual building block used by the deeper ResNets: a 1x1 -> 3x3 -> 1x1
// convolution stack plus an identity (or projected) shortcut.
struct VISION_API Bottleneck : torch::nn::Module {
  // ResNetImpl reads `expansion` and the members while assembling stages.
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  // Stride of the block (also its overall spatial downsampling factor).
  int64_t stride;
  // Projection applied to the shortcut when shapes change; empty
  // (is_empty()) for a plain identity shortcut.
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};

  // Channel multiplier from `planes` to the block's output width
  // (value is defined in the .cpp).
  static int expansion;

  Bottleneck(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      const torch::nn::Sequential& downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor X);
};
} // namespace _resnetimpl
// Generic ResNet assembled from `Block` (BasicBlock or Bottleneck).
template <typename Block>
struct ResNetImpl : torch::nn::Module {
  // `groups`/`base_width` configure grouped (ResNeXt / wide) convolutions.
  // `inplanes` tracks the running channel count while the stages are built;
  // it MUST be declared before the layer members because _make_layer
  // mutates it during the constructor's initializer list.
  int64_t groups, base_width, inplanes;
  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm2d bn1;
  torch::nn::Sequential layer1, layer2, layer3, layer4;
  torch::nn::Linear fc;

  // Builds one stage of `blocks` residual blocks; the first block may
  // downsample via `stride`. Side effect: advances `inplanes`.
  torch::nn::Sequential _make_layer(
      int64_t planes,
      int64_t blocks,
      int64_t stride = 1);

  // `layers` holds the four per-stage block counts.
  explicit ResNetImpl(
      const std::vector<int>& layers,
      int64_t num_classes = 1000,
      bool zero_init_residual = false,
      int64_t groups = 1,
      int64_t width_per_group = 64);

  torch::Tensor forward(torch::Tensor X);
};
// Builds one residual stage of `blocks` blocks. Only the first block can
// change stride/width; the rest map out_channels -> out_channels with an
// identity shortcut. Side effect: advances `inplanes` to the stage output.
template <typename Block>
torch::nn::Sequential ResNetImpl<Block>::_make_layer(
    int64_t planes,
    int64_t blocks,
    int64_t stride) {
  const auto out_channels = planes * Block::expansion;

  // Projection shortcut is needed whenever the first block changes the
  // spatial size or the channel count.
  torch::nn::Sequential shortcut = nullptr;
  if (stride != 1 || inplanes != out_channels) {
    shortcut = torch::nn::Sequential(
        _resnetimpl::conv1x1(inplanes, out_channels, stride),
        torch::nn::BatchNorm2d(out_channels));
  }

  torch::nn::Sequential stage;
  stage->push_back(
      Block(inplanes, planes, stride, shortcut, groups, base_width));
  inplanes = out_channels;
  for (int64_t block_idx = 1; block_idx < blocks; ++block_idx)
    stage->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width));
  return stage;
}
// Builds the full network: 7x7 stem, four residual stages, linear head.
// NOTE: the initializer list relies on member declaration order —
// `inplanes` is initialized to 64 before `layer1..layer4`, and each
// _make_layer call mutates it for the next stage.
template <typename Block>
ResNetImpl<Block>::ResNetImpl(
    const std::vector<int>& layers,
    int64_t num_classes,
    bool zero_init_residual,
    int64_t groups,
    int64_t width_per_group)
    : groups(groups),
      base_width(width_per_group),
      inplanes(64),
      conv1(
          torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false)),
      bn1(64),
      layer1(_make_layer(64, layers[0])),
      layer2(_make_layer(128, layers[1], 2)),
      layer3(_make_layer(256, layers[2], 2)),
      layer4(_make_layer(512, layers[3], 2)),
      fc(512 * Block::expansion, num_classes) {
  register_module("conv1", conv1);
  register_module("bn1", bn1);
  register_module("fc", fc);
  register_module("layer1", layer1);
  register_module("layer2", layer2);
  register_module("layer3", layer3);
  register_module("layer4", layer4);

  // Weight init: Kaiming (fan-out, ReLU) for convolutions; unit weight and
  // zero bias for batch-norm layers.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(
          M->weight,
          /*a=*/0,
          torch::kFanOut,
          torch::kReLU);
    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::constant_(M->weight, 1);
      torch::nn::init::constant_(M->bias, 0);
    }
  }

  // Zero-initialize the last BN in each residual branch, so that the residual
  // branch starts with zeros, and each residual block behaves like an
  // identity. This improves the model by 0.2~0.3% according to
  // https://arxiv.org/abs/1706.02677
  if (zero_init_residual)
    for (auto& module : modules(/*include_self=*/false)) {
      if (auto* M = dynamic_cast<_resnetimpl::Bottleneck*>(module.get()))
        torch::nn::init::constant_(M->bn3->weight, 0);
      else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get()))
        torch::nn::init::constant_(M->bn2->weight, 0);
    }

  // The C++ model zoo is deprecated; warn users at construction time.
  modelsimpl::deprecation_warning();
}
// Standard ResNet forward pass: stem, four residual stages, global average
// pooling, then the fully-connected classifier head.
template <typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
  // Stem: 7x7 conv, BN + in-place ReLU, 3x3/stride-2 max-pool.
  auto h = bn1->forward(conv1->forward(x)).relu_();
  h = torch::max_pool2d(h, 3, 2, 1);

  // Residual stages.
  for (auto* stage : {&layer1, &layer2, &layer3, &layer4})
    h = (*stage)->forward(h);

  // Head: global average pool, flatten, linear classifier.
  h = torch::adaptive_avg_pool2d(h, {1, 1});
  h = h.reshape({h.size(0), -1});
  return fc->forward(h);
}
// Concrete architecture Impls. The 18/34-layer variants build on
// BasicBlock; the deeper, ResNeXt, and wide variants build on Bottleneck.
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  explicit ResNet18Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  explicit ResNet34Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet50Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet101Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet152Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNext50_32x4dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNext101_32x8dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit WideResNet50_2Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit WideResNet101_2Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

// ModuleHolder wrapper so the models follow the usual torch::nn
// value-semantics pattern.
template <typename Block>
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
  using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};

// Value-semantic handles (e.g. ResNet18 wraps ResNet18Impl).
TORCH_MODULE(ResNet18);
TORCH_MODULE(ResNet34);
TORCH_MODULE(ResNet50);
TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
TORCH_MODULE(WideResNet50_2);
TORCH_MODULE(WideResNet101_2);
} // namespace models
} // namespace vision
#include "shufflenetv2.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
// Shorthand for convolution options used throughout this file.
using Options = torch::nn::Conv2dOptions;

// Interleaves channels across `groups` so information mixes between the
// grouped branches (ShuffleNet's core operation): reshape (N, C, H, W) to
// (N, groups, C/groups, H, W), swap the two channel axes, flatten back.
torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) {
  const auto sizes = x.sizes();
  const auto n = sizes[0];
  const auto c = sizes[1];
  const auto h = sizes[2];
  const auto w = sizes[3];

  auto grouped = x.view({n, groups, c / groups, h, w});
  auto shuffled = torch::transpose(grouped, 1, 2).contiguous();
  return shuffled.view({n, -1, h, w});
}
// 1x1 pointwise convolution: stride 1, no padding, no bias.
torch::nn::Conv2d conv11(int64_t input, int64_t output) {
  return torch::nn::Conv2d(
      Options(input, output, 1).stride(1).padding(0).bias(false));
}
// 3x3 depthwise convolution (groups == input channels), padding 1, no bias.
torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) {
  return torch::nn::Conv2d(Options(input, output, 3)
                               .stride(stride)
                               .padding(1)
                               .bias(false)
                               .groups(input));
}
// ShuffleNet V2 inverted-residual unit. With stride 1 the input is split in
// half along channels and only one half is transformed; with stride > 1 both
// branches process (and downsample) the full input.
struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
  // Spatial stride of the unit (1 = keep size, >1 = downsample).
  int64_t stride;
  // branch1 is only built (and only registered) when stride > 1.
  torch::nn::Sequential branch1{nullptr}, branch2{nullptr};

  ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride)
      : stride(stride) {
    TORCH_CHECK(stride >= 1 && stride <= 3, "illegal stride value");

    // Each branch contributes half of the output channels.
    auto branch_features = oup / 2;
    // For stride 1 the channel split requires inp == 2 * branch_features.
    TORCH_CHECK(stride != 1 || inp == branch_features << 1);

    if (stride > 1) {
      // Shortcut branch: depthwise 3x3 (downsampling) then pointwise 1x1.
      branch1 = torch::nn::Sequential(
          conv33(inp, inp, stride),
          torch::nn::BatchNorm2d(inp),
          conv11(inp, branch_features),
          torch::nn::BatchNorm2d(branch_features),
          torch::nn::Functional(modelsimpl::relu_));
    }

    // Main branch: 1x1 -> depthwise 3x3 -> 1x1, with BN after each conv.
    branch2 = torch::nn::Sequential(
        conv11(stride > 1 ? inp : branch_features, branch_features),
        torch::nn::BatchNorm2d(branch_features),
        torch::nn::Functional(modelsimpl::relu_),
        conv33(branch_features, branch_features, stride),
        torch::nn::BatchNorm2d(branch_features),
        conv11(branch_features, branch_features),
        torch::nn::BatchNorm2d(branch_features),
        torch::nn::Functional(modelsimpl::relu_));

    if (!branch1.is_empty())
      register_module("branch1", branch1);
    register_module("branch2", branch2);
  }

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor out;
    if (stride == 1) {
      // Split channels in half: pass one half through untouched, transform
      // the other, then concatenate.
      auto chunks = x.chunk(2, 1);
      out = torch::cat({chunks[0], branch2->forward(chunks[1])}, 1);
    } else
      out = torch::cat({branch1->forward(x), branch2->forward(x)}, 1);
    // Mix information between the two concatenated halves.
    out = ::vision::models::channel_shuffle(out, 2);
    return out;
  }
};

TORCH_MODULE(ShuffleNetV2InvertedResidual);
// Builds a ShuffleNet V2 network: a conv stem, three stages of inverted
// residual units (repeat counts from `stage_repeats`, widths from
// `stage_out_channels`), a final 1x1 conv, and a linear classifier.
ShuffleNetV2Impl::ShuffleNetV2Impl(
    const std::vector<int64_t>& stage_repeats,
    const std::vector<int64_t>& stage_out_channels,
    int64_t num_classes) {
  TORCH_CHECK(
      stage_repeats.size() == 3,
      "expected stage_repeats as vector of 3 positive ints");
  TORCH_CHECK(
      stage_out_channels.size() == 5,
      "expected stage_out_channels as vector of 5 positive ints");
  _stage_out_channels = stage_out_channels;

  // Stem: 3x3/stride-2 conv from RGB to the first stage width.
  int64_t input_channels = 3;
  auto output_channels = _stage_out_channels[0];
  conv1 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 3)
                            .stride(2)
                            .padding(1)
                            .bias(false)),
      torch::nn::BatchNorm2d(output_channels),
      torch::nn::Functional(modelsimpl::relu_));
  input_channels = output_channels;

  // Stages 2-4: one stride-2 unit followed by (repeats - 1) stride-1 units.
  // NOTE(review): the loop below assumes repeats >= 1 — `repeats - 1` is
  // cast to size_t, so a zero repeat count would underflow; the TORCH_CHECKs
  // above validate only the vector lengths, not positivity.
  std::vector<torch::nn::Sequential> stages = {stage2, stage3, stage4};
  for (size_t i = 0; i < stages.size(); ++i) {
    auto& seq = stages[i];
    auto repeats = stage_repeats[i];
    // Shadows the outer `output_channels` on purpose: stage width comes
    // from the next entry of _stage_out_channels.
    auto output_channels = _stage_out_channels[i + 1];
    seq->push_back(
        ShuffleNetV2InvertedResidual(input_channels, output_channels, 2));
    for (size_t j = 0; j < size_t(repeats - 1); ++j)
      seq->push_back(
          ShuffleNetV2InvertedResidual(output_channels, output_channels, 1));
    input_channels = output_channels;
  }

  // Final 1x1 conv up to the last entry of _stage_out_channels.
  output_channels = _stage_out_channels.back();
  conv5 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 1)
                            .stride(1)
                            .padding(0)
                            .bias(false)),
      torch::nn::BatchNorm2d(output_channels),
      torch::nn::Functional(modelsimpl::relu_));
  fc = torch::nn::Linear(output_channels, num_classes);

  register_module("conv1", conv1);
  register_module("stage2", stage2);
  register_module("stage3", stage3);
  register_module("stage4", stage4);
  // NOTE(review): conv5 is registered under the name "conv2", which does not
  // match the member name (every other module is registered under its own
  // name). This determines serialized parameter keys, so changing it would
  // break existing checkpoints — confirm whether it is intentional before
  // renaming to "conv5".
  register_module("conv2", conv5);
  register_module("fc", fc);

  // The C++ model zoo is deprecated; warn users at construction time.
  modelsimpl::deprecation_warning();
}
// Forward pass: stem conv + max-pool, three shuffle stages, final 1x1 conv,
// global average pooling, then the linear classifier.
torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) {
  auto h = torch::max_pool2d(conv1->forward(x), 3, 2, 1);

  for (auto* stage : {&stage2, &stage3, &stage4})
    h = (*stage)->forward(h);

  h = conv5->forward(h);
  // Global average pool over the spatial dimensions.
  h = h.mean({2, 3});
  return fc->forward(h);
}
// Width-multiplier variants: identical stage repeats {4, 8, 4}, differing
// only in the per-stage channel widths.
ShuffleNetV2_x0_5Impl::ShuffleNetV2_x0_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 48, 96, 192, 1024}, num_classes) {}

ShuffleNetV2_x1_0Impl::ShuffleNetV2_x1_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 116, 232, 464, 1024}, num_classes) {}

ShuffleNetV2_x1_5Impl::ShuffleNetV2_x1_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 176, 352, 704, 1024}, num_classes) {}

ShuffleNetV2_x2_0Impl::ShuffleNetV2_x2_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 244, 488, 976, 2048}, num_classes) {}
} // namespace models
} // namespace vision
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment