Commit b5db97b4 authored by Shahriar's avatar Shahriar Committed by Francisco Massa
Browse files

C++ Models (#728)

* Added the existing code

* Added squeezenet and fixed some stuff in the other models

* Wrote DenseNet and a part of InceptionV3

Going to clean and check all of the models and finish inception

* Fixed some errors in the models

Next step is writing inception and comparing with python code again.

* Completed inception and changed models directory

* Fixed and wrote some stuff

* Fixed maxpool2d and avgpool2d and adaptiveavgpool2d

* Fixed a few stuff

Moved cmakelists to root and changed the namespace to vision and wrote weight initialization in inception

* Added models namespace and changed cmakelists

the project is now installable

* Removed some comments

* Changed style to pytorch style, added some comments and fixed some minor errors

* Removed truncated normal init

* Changed classes to structs and fixed a few errors

* Replaced modelsimpl structs with functional wherever possible

* Changed adaptive average pool from struct to function

* Wrote a max_pool2d wrapper and added some comments

* Replaced xavier init with kaiming init

* Fixed an error in kaiming inits

* Added model conversion and tests

* Fixed a typo in alexnet and removed tests from cmake

* Made an extension of tests and added module names to Densenet

* Added python tests

* Added MobileNet and GoogLeNet models

* Added tests and conversions for new models and fixed a few errors

* Updated Alexnet and VGG

* Updated Densenet, Squeezenet and Inception

* Added ResNexts and their conversions

* Added tests for ResNexts

* Wrote tools necessary to write ShuffleNet

* Added ShuffleNetV2

* Fixed some errors in ShuffleNetV2

* Added conversions for shufflenetv2

* Fixed the errors in test_models.cpp

* Updated setup.py

* Fixed flake8 error on test_cpp_models.py

* Changed view to reshape in forward of ResNet

* Updated ShuffleNetV2

* Split extensions to tests and ops

* Fixed test extension

* Fixed image path in test_cpp_models.py

* Fixed image path in test_cpp_models.py

* Fixed a few things in test_cpp_models.py

* Put the test models in evaluation mode

* Fixed registering error in GoogLeNet

* Updated setup.py

* write test_cpp_models.py with unittest

* Fixed a problem with pytest in test_cpp_models.py

* Fixed a lint problem
parent 394de98e
# CMAKE_CXX_STANDARD is only honored from CMake 3.1 onwards; requiring 2.8
# would make the set() below a silent no-op and build without -std=c++11.
cmake_minimum_required(VERSION 3.1)
project(torchvision)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Torch REQUIRED)

# Public headers installed alongside the library.
file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h)
file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h)
# Headers are listed with the sources so IDE generators display them.
file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)

add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES})
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp)
target_link_libraries(convertmodels "${PROJECT_NAME}")
target_link_libraries(convertmodels "${TORCH_LIBRARIES}")

#add_executable(testmodels test/test_models.cpp)
#target_link_libraries(testmodels "${PROJECT_NAME}")
#target_link_libraries(testmodels "${TORCH_LIBRARIES}")

install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME})
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models)
......@@ -89,6 +89,15 @@ def get_extensions():
sources = main_file + source_cpu
extension = CppExtension
test_dir = os.path.join(this_dir, 'test')
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
source_models = glob.glob(os.path.join(models_dir, '*.cpp'))
test_file = [os.path.join(test_dir, s) for s in test_file]
source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models
define_macros = []
extra_compile_args = {}
......@@ -109,6 +118,7 @@ def get_extensions():
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
tests_include_dirs = [test_dir, models_dir]
ext_modules = [
extension(
......@@ -117,6 +127,13 @@ def get_extensions():
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
extension(
'torchvision._C_tests',
tests,
include_dirs=tests_include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
......
import torch
import os
import unittest
from torchvision import models, transforms, _C_tests
from PIL import Image
import torchvision.transforms.functional as F
def process_model(model, tensor, func, name):
    # Trace the python model on the given input, save the trace, then run
    # the C++ forward (func) on the saved file and require that it matches
    # the python forward exactly.
    model.eval()
    traced = torch.jit.trace(model, tensor)
    traced.save("model.pt")

    expected = model.forward(tensor)
    actual = func("model.pt", tensor)
    assert torch.allclose(expected, actual), 'Output mismatch of ' + name + ' models'
def read_image1():
    # Load the Grace Hopper test image, resize it to 224x224 and return it
    # as a 1 x 3 x 224 x 224 float tensor.
    asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
    img = Image.open(asset_path).resize((224, 224))
    return F.to_tensor(img).view(1, 3, 224, 224)
def read_image2():
    # Load the test image at 299x299 (the inception v3 input size) and
    # stack two copies into a 2 x 3 x 299 x 299 batch.
    asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
    img = Image.open(asset_path).resize((299, 299))
    single = F.to_tensor(img).view(1, 3, 299, 299)
    return torch.cat([single, single], 0)
class Tester(unittest.TestCase):
    """Checks each C++ model port against its python counterpart.

    Every test traces the python model on a fixed input image and asserts
    that the matching ``_C_tests.forward_*`` binding produces the same
    output (see ``process_model``).
    """
    # False so the tests never download pretrained weights.
    pretrained = False
    # Shared 1 x 3 x 224 x 224 input used by all 224-pixel models.
    image = read_image1()

    def test_alexnet(self):
        process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet')

    def test_vgg11(self):
        process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11')

    def test_vgg13(self):
        process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13')

    def test_vgg16(self):
        process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16')

    def test_vgg19(self):
        process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19')

    def test_vgg11_bn(self):
        process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN')

    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN')

    def test_vgg16_bn(self):
        process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN')

    def test_vgg19_bn(self):
        process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN')

    def test_resnet18(self):
        process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18')

    def test_resnet34(self):
        process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34')

    def test_resnet50(self):
        process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50')

    def test_resnet101(self):
        process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101')

    def test_resnet152(self):
        process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152')

    def test_resnext50_32x4d(self):
        # NOTE(review): unlike the other tests no pretrained flag is passed
        # here -- presumably these constructors default to random weights;
        # confirm against the resnext factory signatures.
        process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d')

    def test_resnext101_32x8d(self):
        process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')

    def test_squeezenet1_0(self):
        process_model(models.squeezenet1_0(self.pretrained), self.image,
                      _C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

    def test_squeezenet1_1(self):
        process_model(models.squeezenet1_1(self.pretrained), self.image,
                      _C_tests.forward_squeezenet1_1, 'Squeezenet1.1')

    def test_densenet121(self):
        process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121')

    def test_densenet169(self):
        process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169')

    def test_densenet201(self):
        process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201')

    def test_densenet161(self):
        process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161')

    def test_mobilenet_v2(self):
        process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet')

    def test_googlenet(self):
        process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet')

    def test_inception_v3(self):
        # Inception v3 needs 299x299 inputs; this assignment shadows the
        # class-level 224px image on the instance only.
        self.image = read_image2()
        process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3')
if __name__ == '__main__':
    # Allow running this file directly as well as through pytest.
    unittest.main()
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include "../torchvision/csrc/models/models.h"
using namespace vision::models;
// Loads a serialized model of type Model from input_path, switches it to
// eval mode, and returns the result of one forward pass on x.
template <typename Model>
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
  Model network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x);
}
// Named, non-template entry points for the pybind11 bindings below: the
// forward_model template cannot be exposed to python directly, so every
// architecture gets a small wrapper that instantiates it.
torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
  return forward_model<AlexNet>(input_path, x);
}

torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11>(input_path, x);
}

torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13>(input_path, x);
}

torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16>(input_path, x);
}

torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19>(input_path, x);
}

torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG11BN>(input_path, x);
}

torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG13BN>(input_path, x);
}

torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG16BN>(input_path, x);
}

torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
  return forward_model<VGG19BN>(input_path, x);
}

torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet18>(input_path, x);
}

torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet34>(input_path, x);
}

torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
  return forward_model<ResNet50>(input_path, x);
}

torch::Tensor forward_resnet101(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet101>(input_path, x);
}

torch::Tensor forward_resnet152(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNet152>(input_path, x);
}

torch::Tensor forward_resnext50_32x4d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext50_32x4d>(input_path, x);
}

torch::Tensor forward_resnext101_32x8d(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<ResNext101_32x8d>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_0>(input_path, x);
}

torch::Tensor forward_squeezenet1_1(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<SqueezeNet1_1>(input_path, x);
}

torch::Tensor forward_densenet121(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet121>(input_path, x);
}

torch::Tensor forward_densenet169(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet169>(input_path, x);
}

torch::Tensor forward_densenet201(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet201>(input_path, x);
}

torch::Tensor forward_densenet161(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<DenseNet161>(input_path, x);
}

torch::Tensor forward_mobilenetv2(
    const std::string& input_path,
    torch::Tensor x) {
  return forward_model<MobileNetV2>(input_path, x);
}
// GoogLeNet and InceptionV3 forwards return a struct carrying the main
// output plus auxiliary-classifier outputs, so they cannot reuse
// forward_model; only the primary `.output` tensor is compared in tests.
torch::Tensor forward_googlenet(
    const std::string& input_path,
    torch::Tensor x) {
  GoogLeNet network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}

torch::Tensor forward_inceptionv3(
    const std::string& input_path,
    torch::Tensor x) {
  InceptionV3 network;
  torch::load(network, input_path);
  network->eval();
  return network->forward(x).output;
}
// Exposes every forward_* wrapper to python as torchvision._C_tests
// (TORCH_EXTENSION_NAME is defined by the torch extension build).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");
  m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
  m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
  m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
  m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");

  m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
  m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
  m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
  m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");

  m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
  m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
  m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
  m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
  m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
  m.def(
      "forward_resnext50_32x4d",
      &forward_resnext50_32x4d,
      "forward_resnext50_32x4d");
  m.def(
      "forward_resnext101_32x8d",
      &forward_resnext101_32x8d,
      "forward_resnext101_32x8d");

  m.def(
      "forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
  m.def(
      "forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");

  m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
  m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
  m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
  m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");

  m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");

  m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
  m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
}
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include "../models/models.h"
using namespace vision::models;
// Round-trips a checkpoint: reads the python-exported weights from
// input_path into a freshly constructed Model and re-serializes them to
// output_path in the C++ archive format.
template <typename Model>
void convert_and_save_model(
    const std::string& input_path,
    const std::string& output_path) {
  Model network;
  torch::load(network, input_path);
  torch::save(network, output_path);

  // Strip the "_python" suffix (and extension) for the progress message.
  const auto name = input_path.substr(0, input_path.find("_python"));
  std::cout << "finished loading and saving " << name << std::endl;
}
// Converts every supported architecture from the python-exported
// "<name>_python.pt" files (expected in the working directory) to their
// C++ "<name>_cpp.pt" counterparts. argc/argv are accepted but unused.
int main(int argc, const char* argv[]) {
  convert_and_save_model<AlexNet>("alexnet_python.pt", "alexnet_cpp.pt");

  convert_and_save_model<VGG11>("vgg11_python.pt", "vgg11_cpp.pt");
  convert_and_save_model<VGG13>("vgg13_python.pt", "vgg13_cpp.pt");
  convert_and_save_model<VGG16>("vgg16_python.pt", "vgg16_cpp.pt");
  convert_and_save_model<VGG19>("vgg19_python.pt", "vgg19_cpp.pt");

  convert_and_save_model<VGG11BN>("vgg11bn_python.pt", "vgg11bn_cpp.pt");
  convert_and_save_model<VGG13BN>("vgg13bn_python.pt", "vgg13bn_cpp.pt");
  convert_and_save_model<VGG16BN>("vgg16bn_python.pt", "vgg16bn_cpp.pt");
  convert_and_save_model<VGG19BN>("vgg19bn_python.pt", "vgg19bn_cpp.pt");

  convert_and_save_model<ResNet18>("resnet18_python.pt", "resnet18_cpp.pt");
  convert_and_save_model<ResNet34>("resnet34_python.pt", "resnet34_cpp.pt");
  convert_and_save_model<ResNet50>("resnet50_python.pt", "resnet50_cpp.pt");
  convert_and_save_model<ResNet101>("resnet101_python.pt", "resnet101_cpp.pt");
  convert_and_save_model<ResNet152>("resnet152_python.pt", "resnet152_cpp.pt");
  convert_and_save_model<ResNext50_32x4d>(
      "resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt");
  convert_and_save_model<ResNext101_32x8d>(
      "resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt");

  convert_and_save_model<SqueezeNet1_0>(
      "squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt");
  convert_and_save_model<SqueezeNet1_1>(
      "squeezenet1_1_python.pt", "squeezenet1_1_cpp.pt");

  convert_and_save_model<DenseNet121>(
      "densenet121_python.pt", "densenet121_cpp.pt");
  convert_and_save_model<DenseNet169>(
      "densenet169_python.pt", "densenet169_cpp.pt");
  convert_and_save_model<DenseNet201>(
      "densenet201_python.pt", "densenet201_cpp.pt");
  convert_and_save_model<DenseNet161>(
      "densenet161_python.pt", "densenet161_cpp.pt");

  convert_and_save_model<MobileNetV2>(
      "mobilenetv2_python.pt", "mobilenetv2_cpp.pt");

  convert_and_save_model<ShuffleNetV2_x0_5>(
      "shufflenetv2_x0_5_python.pt", "shufflenetv2_x0_5_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x1_0>(
      "shufflenetv2_x1_0_python.pt", "shufflenetv2_x1_0_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x1_5>(
      "shufflenetv2_x1_5_python.pt", "shufflenetv2_x1_5_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x2_0>(
      "shufflenetv2_x2_0_python.pt", "shufflenetv2_x2_0_cpp.pt");

  convert_and_save_model<GoogLeNet>("googlenet_python.pt", "googlenet_cpp.pt");
  convert_and_save_model<InceptionV3>(
      "inceptionv3_python.pt", "inceptionv3_cpp.pt");

  return 0;
}
#include "alexnet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
// Builds the AlexNet feature extractor (five conv layers interleaved with
// ReLU and three 3x3/stride-2 max pools) and the three-layer classifier.
AlexNetImpl::AlexNetImpl(int64_t num_classes) {
  features = torch::nn::Sequential(
      torch::nn::Conv2d(
          torch::nn::Conv2dOptions(3, 64, 11).stride(4).padding(2)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 192, 5).padding(2)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(192, 384, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(384, 256, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::max_pool2d, 3, 2));

  // Dropout -> Linear -> ReLU twice, then the final projection to
  // num_classes. 256 * 6 * 6 is the flattened size after the 6x6
  // adaptive pool in forward().
  classifier = torch::nn::Sequential(
      torch::nn::Dropout(),
      torch::nn::Linear(256 * 6 * 6, 4096),
      torch::nn::Functional(torch::relu),
      torch::nn::Dropout(),
      torch::nn::Linear(4096, 4096),
      torch::nn::Functional(torch::relu),
      torch::nn::Linear(4096, num_classes));

  register_module("features", features);
  register_module("classifier", classifier);
}
torch::Tensor AlexNetImpl::forward(torch::Tensor x) {
  // Convolutional trunk, fixed 6x6 spatial pooling, then the fully
  // connected classifier on the flattened activations.
  auto out = features->forward(x);
  out = torch::adaptive_avg_pool2d(out, {6, 6});
  out = out.view({out.size(0), -1});
  return classifier->forward(out);
}
} // namespace models
} // namespace vision
#ifndef ALEXNET_H
#define ALEXNET_H

#include <torch/torch.h>

namespace vision {
namespace models {
// AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
struct AlexNetImpl : torch::nn::Module {
  // Convolutional feature extractor and fully connected classifier;
  // both are built in the constructor (alexnet.cpp).
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  AlexNetImpl(int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

// Generates the module holder AlexNet wrapping AlexNetImpl.
TORCH_MODULE(AlexNet);

} // namespace models
} // namespace vision

#endif // ALEXNET_H
#include "densenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// Single dense layer: BN -> ReLU -> 1x1 bottleneck conv -> BN -> ReLU ->
// 3x3 conv, whose output is concatenated onto the layer input.
struct _DenseLayerImpl : torch::nn::SequentialImpl {
  // Dropout probability applied to the newly produced features.
  double drop_rate;

  _DenseLayerImpl(
      int64_t num_input_features,
      int64_t growth_rate,
      int64_t bn_size,
      double drop_rate)
      : drop_rate(drop_rate) {
    push_back("norm1", torch::nn::BatchNorm(num_input_features));
    push_back("relu1", torch::nn::Functional(modelsimpl::relu_));
    push_back(
        "conv1",
        torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
                              .stride(1)
                              .with_bias(false)));
    push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
    push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
    push_back(
        "conv2",
        torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3)
                              .stride(1)
                              .padding(1)
                              .with_bias(false)));
  }

  torch::Tensor forward(torch::Tensor x) {
    auto new_features = torch::nn::SequentialImpl::forward(x);
    if (drop_rate > 0)
      new_features =
          torch::dropout(new_features, drop_rate, this->is_training());

    // Dense connectivity: the input is passed through alongside the new
    // growth_rate feature maps.
    return torch::cat({x, new_features}, 1);
  }
};

TORCH_MODULE(_DenseLayer);
// Stack of num_layers dense layers; each layer sees the channels of the
// block input plus everything produced by the layers before it.
struct _DenseBlockImpl : torch::nn::SequentialImpl {
  _DenseBlockImpl(
      int64_t num_layers,
      int64_t num_input_features,
      int64_t bn_size,
      int64_t growth_rate,
      double drop_rate) {
    for (int64_t i = 0; i < num_layers; ++i) {
      auto layer = _DenseLayer(
          num_input_features + i * growth_rate,
          growth_rate,
          bn_size,
          drop_rate);
      // 1-based names to match the python DenseNet submodule naming.
      push_back("denselayer" + std::to_string(i + 1), layer);
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(_DenseBlock);
// Transition layer between dense blocks: BN -> ReLU -> 1x1 conv that
// reduces the channel count, followed by 2x2 average pooling.
struct _TransitionImpl : torch::nn::SequentialImpl {
  _TransitionImpl(int64_t num_input_features, int64_t num_output_features) {
    push_back("norm", torch::nn::BatchNorm(num_input_features));
    // Fixed: the name was registered as "relu " (trailing space), which
    // broke the "norm"/"relu"/"conv"/"pool" naming convention and the
    // parity with the python model's submodule names. Functional holds no
    // parameters, so serialized checkpoints are unaffected.
    push_back("relu", torch::nn::Functional(modelsimpl::relu_));
    push_back(
        "conv",
        torch::nn::Conv2d(Options(num_input_features, num_output_features, 1)
                              .stride(1)
                              .with_bias(false)));
    push_back(
        "pool", torch::nn::Functional(torch::avg_pool2d, 2, 2, 0, false, true));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(_Transition);
// Builds the DenseNet trunk: a 7x7 stem, the configured dense blocks with
// halving transitions in between, a final batch norm, and the classifier.
DenseNetImpl::DenseNetImpl(
    int64_t num_classes,
    int64_t growth_rate,
    std::vector<int64_t> block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate) {
  // First convolution
  features = torch::nn::Sequential();
  features->push_back(
      "conv0",
      torch::nn::Conv2d(Options(3, num_init_features, 7)
                            .stride(2)
                            .padding(3)
                            .with_bias(false)));

  features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
  features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
  features->push_back(
      "pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false));

  // Each denseblock
  auto num_features = num_init_features;
  for (size_t i = 0; i < block_config.size(); ++i) {
    auto num_layers = block_config[i];
    _DenseBlock block(
        num_layers, num_features, bn_size, growth_rate, drop_rate);

    features->push_back("denseblock" + std::to_string(i + 1), block);
    num_features = num_features + num_layers * growth_rate;

    // A channel-halving transition follows every block except the last.
    if (i != block_config.size() - 1) {
      auto trans = _Transition(num_features, num_features / 2);
      features->push_back("transition" + std::to_string(i + 1), trans);
      num_features = num_features / 2;
    }
  }

  // Final batch norm
  features->push_back("norm5", torch::nn::BatchNorm(num_features));
  // Linear layer
  classifier = torch::nn::Linear(num_features, num_classes);

  register_module("features", features);
  register_module("classifier", classifier);

  // Official init from torch repo.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(M->weight);
    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::constant_(M->weight, 1);
      torch::nn::init::constant_(M->bias, 0);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
      torch::nn::init::constant_(M->bias, 0);
  }
}
torch::Tensor DenseNetImpl::forward(torch::Tensor x) {
  // The feature extractor ends with a bare batch norm, so the final ReLU
  // is applied here before global pooling and classification.
  auto out = torch::relu_(this->features->forward(x));
  out = torch::adaptive_avg_pool2d(out, {1, 1});
  out = out.view({out.size(0), -1});
  return this->classifier->forward(out);
}
// The concrete DenseNet variants differ only in their default
// hyper-parameters (declared in densenet.h); each constructor simply
// forwards every argument to DenseNetImpl.
DenseNet121Impl::DenseNet121Impl(
    int64_t num_classes,
    int64_t growth_rate,
    std::vector<int64_t> block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet169Impl::DenseNet169Impl(
    int64_t num_classes,
    int64_t growth_rate,
    std::vector<int64_t> block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet201Impl::DenseNet201Impl(
    int64_t num_classes,
    int64_t growth_rate,
    std::vector<int64_t> block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}

DenseNet161Impl::DenseNet161Impl(
    int64_t num_classes,
    int64_t growth_rate,
    std::vector<int64_t> block_config,
    int64_t num_init_features,
    int64_t bn_size,
    double drop_rate)
    : DenseNetImpl(
          num_classes,
          growth_rate,
          block_config,
          num_init_features,
          bn_size,
          drop_rate) {}
} // namespace models
} // namespace vision
#ifndef DENSENET_H
#define DENSENET_H

#include <torch/torch.h>

namespace vision {
namespace models {
// Densenet-BC model class, based on
// "Densely Connected Convolutional Networks"
// <https://arxiv.org/pdf/1608.06993.pdf>

// Args:
//   num_classes (int) - number of classification classes
//   growth_rate (int) - how many filters to add each layer (`k` in paper)
//   block_config (list of 4 ints) - how many layers in each pooling block
//   num_init_features (int) - the number of filters to learn in the first
//     convolution layer
//   bn_size (int) - multiplicative factor for number of bottle neck layers
//     (i.e. bn_size * k features in the bottleneck layer)
//   drop_rate (float) - dropout rate after each dense layer
struct DenseNetImpl : torch::nn::Module {
  // Dense blocks plus transitions; ends with a bare batch norm (the final
  // ReLU is applied in forward()).
  torch::nn::Sequential features{nullptr};
  torch::nn::Linear classifier{nullptr};

  DenseNetImpl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);

  torch::Tensor forward(torch::Tensor x);
};

// The four standard variants below only change the default growth rate,
// block configuration and initial feature count.
struct DenseNet121Impl : DenseNetImpl {
  DenseNet121Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct DenseNet169Impl : DenseNetImpl {
  DenseNet169Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 32, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct DenseNet201Impl : DenseNetImpl {
  DenseNet201Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 48, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

struct DenseNet161Impl : DenseNetImpl {
  DenseNet161Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 48,
      std::vector<int64_t> block_config = {6, 12, 36, 24},
      int64_t num_init_features = 96,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// Module holders wrapping the Impl classes above.
TORCH_MODULE(DenseNet);
TORCH_MODULE(DenseNet121);
TORCH_MODULE(DenseNet169);
TORCH_MODULE(DenseNet201);
TORCH_MODULE(DenseNet161);

} // namespace models
} // namespace vision

#endif // DENSENET_H
#include "googlenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
namespace _googlenetimpl {
// Conv2d -> BatchNorm -> in-place ReLU, the building block used by every
// GoogLeNet branch. The conv bias is disabled because the batch norm that
// follows makes it redundant.
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
  options.with_bias(false);
  conv = torch::nn::Conv2d(options);
  bn = torch::nn::BatchNorm(
      torch::nn::BatchNormOptions(options.output_channels()).eps(0.001));

  register_module("conv", conv);
  register_module("bn", bn);
}

torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) {
  x = conv->forward(x);
  x = bn->forward(x);
  return x.relu_();
}
// GoogLeNet inception module: four parallel branches (1x1 conv; 1x1->3x3
// conv; 1x1->3x3 conv; maxpool->1x1 conv) concatenated along channels.
InceptionImpl::InceptionImpl(
    int64_t in_channels,
    int64_t ch1x1,
    int64_t ch3x3red,
    int64_t ch3x3,
    int64_t ch5x5red,
    int64_t ch5x5,
    int64_t pool_proj) {
  branch1 = BasicConv2d(Options(in_channels, ch1x1, 1));

  branch2->push_back(BasicConv2d(Options(in_channels, ch3x3red, 1)));
  branch2->push_back(BasicConv2d(Options(ch3x3red, ch3x3, 3).padding(1)));

  // NOTE(review): despite the ch5x5 parameter names this branch uses a 3x3
  // kernel, mirroring the python GoogLeNet port -- confirm intent there.
  branch3->push_back(BasicConv2d(Options(in_channels, ch5x5red, 1)));
  branch3->push_back(BasicConv2d(Options(ch5x5red, ch5x5, 3).padding(1)));

  branch4->push_back(
      torch::nn::Functional(torch::max_pool2d, 3, 1, 1, 1, true));
  branch4->push_back(BasicConv2d(Options(in_channels, pool_proj, 1)));

  register_module("branch1", branch1);
  register_module("branch2", branch2);
  register_module("branch3", branch3);
  register_module("branch4", branch4);
}

torch::Tensor InceptionImpl::forward(torch::Tensor x) {
  auto b1 = branch1->forward(x);
  auto b2 = branch2->forward(x);
  auto b3 = branch3->forward(x);
  auto b4 = branch4->forward(x);

  return torch::cat({b1, b2, b3, b4}, 1);
}
// Auxiliary classifier head: 1x1 conv to 128 channels, then two fully
// connected layers. 2048 = 128 channels * 4 * 4 after the pooling in
// forward().
InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes) {
  conv = BasicConv2d(Options(in_channels, 128, 1));
  fc1 = torch::nn::Linear(2048, 1024);
  fc2 = torch::nn::Linear(1024, num_classes);

  register_module("conv", conv);
  register_module("fc1", fc1);
  register_module("fc2", fc2);
}
torch::Tensor InceptionAuxImpl::forward(at::Tensor x) {
  // aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
  x = torch::adaptive_avg_pool2d(x, {4, 4});
  // aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
  x = conv->forward(x);
  // N x 128 x 4 x 4
  x = x.view({x.size(0), -1});
  // N x 2048
  x = fc1->forward(x).relu_();
  // N x 1024 (fc1 maps 2048 -> 1024; the previous comment said 2048)
  x = torch::dropout(x, 0.7, is_training());
  // N x 1024
  x = fc2->forward(x);
  // N x num_classes
  return x;
}
} // namespace _googlenetimpl
// Assembles GoogLeNet: the conv stem, nine inception modules, the two
// optional auxiliary classifiers, and the final dropout + linear head.
GoogLeNetImpl::GoogLeNetImpl(
    int64_t num_classes,
    bool aux_logits,
    bool transform_input,
    bool init_weights) {
  this->aux_logits = aux_logits;
  this->transform_input = transform_input;

  conv1 = _googlenetimpl::BasicConv2d(Options(3, 64, 7).stride(2).padding(3));
  conv2 = _googlenetimpl::BasicConv2d(Options(64, 64, 1));
  conv3 = _googlenetimpl::BasicConv2d(Options(64, 192, 3).padding(1));

  inception3a = _googlenetimpl::Inception(192, 64, 96, 128, 16, 32, 32);
  inception3b = _googlenetimpl::Inception(256, 128, 128, 192, 32, 96, 64);

  inception4a = _googlenetimpl::Inception(480, 192, 96, 208, 16, 48, 64);
  inception4b = _googlenetimpl::Inception(512, 160, 112, 224, 24, 64, 64);
  inception4c = _googlenetimpl::Inception(512, 128, 128, 256, 24, 64, 64);
  inception4d = _googlenetimpl::Inception(512, 112, 144, 288, 32, 64, 64);
  inception4e = _googlenetimpl::Inception(528, 256, 160, 320, 32, 128, 128);

  inception5a = _googlenetimpl::Inception(832, 256, 160, 320, 32, 128, 128);
  inception5b = _googlenetimpl::Inception(832, 384, 192, 384, 48, 128, 128);

  // The aux heads are only created (and registered) when requested, so a
  // model built without them can still round-trip through serialization.
  if (aux_logits) {
    aux1 = _googlenetimpl::InceptionAux(512, num_classes);
    aux2 = _googlenetimpl::InceptionAux(528, num_classes);

    register_module("aux1", aux1);
    register_module("aux2", aux2);
  }

  dropout = torch::nn::Dropout(0.2);
  fc = torch::nn::Linear(1024, num_classes);

  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("conv3", conv3);

  register_module("inception3a", inception3a);
  register_module("inception3b", inception3b);

  register_module("inception4a", inception4a);
  register_module("inception4b", inception4b);
  register_module("inception4c", inception4c);
  register_module("inception4d", inception4d);
  register_module("inception4e", inception4e);

  register_module("inception5a", inception5a);
  register_module("inception5b", inception5b);

  register_module("dropout", dropout);
  register_module("fc", fc);

  if (init_weights)
    _initialize_weights();
}
// Initializes conv/linear weights with a normal distribution and batch
// norms with ones/zeros.
void GoogLeNetImpl::_initialize_weights() {
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::normal_(M->weight); // Note: used instead of truncated
                                           // normal initialization
    else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
      torch::nn::init::normal_(M->weight); // Note: used instead of truncated
                                           // normal initialization
    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    }
  }
}
// Full forward pass. Returns the main logits plus the two auxiliary
// classifier outputs; the aux tensors are only populated while training
// with aux_logits enabled and otherwise remain default-constructed
// (undefined) tensors -- callers must check before using them.
GoogLeNetOutput GoogLeNetImpl::forward(torch::Tensor x) {
  if (transform_input) {
    // Re-normalize each channel from ImageNet statistics to the [-1, 1]
    // style normalization the original weights expect.
    auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) +
        (0.485 - 0.5) / 0.5;
    auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) +
        (0.456 - 0.5) / 0.5;
    auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) +
        (0.406 - 0.5) / 0.5;

    x = torch::cat({x_ch0, x_ch1, x_ch2}, 1);
  }

  // N x 3 x 224 x 224
  x = conv1->forward(x);
  // N x 64 x 112 x 112
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);
  // N x 64 x 56 x 56
  x = conv2->forward(x);
  // N x 64 x 56 x 56
  x = conv3->forward(x);
  // N x 192 x 56 x 56
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);

  // N x 192 x 28 x 28
  x = inception3a->forward(x);
  // N x 256 x 28 x 28
  x = inception3b->forward(x);
  // N x 480 x 28 x 28
  x = torch::max_pool2d(x, 3, 2, 0, 1, true);
  // N x 480 x 14 x 14
  x = inception4a->forward(x);
  // N x 512 x 14 x 14

  torch::Tensor aux1;
  if (is_training() && aux_logits)
    aux1 = this->aux1->forward(x);

  x = inception4b->forward(x);
  // N x 512 x 14 x 14
  x = inception4c->forward(x);
  // N x 512 x 14 x 14
  x = inception4d->forward(x);
  // N x 528 x 14 x 14

  torch::Tensor aux2;
  if (is_training() && aux_logits)
    aux2 = this->aux2->forward(x);

  x = inception4e(x);
  // N x 832 x 14 x 14
  x = torch::max_pool2d(x, 2, 2, 0, 1, true);
  // N x 832 x 7 x 7
  x = inception5a(x);
  // N x 832 x 7 x 7
  x = inception5b(x);
  // N x 1024 x 7 x 7

  x = torch::adaptive_avg_pool2d(x, {1, 1});
  // N x 1024 x 1 x 1
  x = x.view({x.size(0), -1});
  // N x 1024
  x = dropout->forward(x);
  x = fc->forward(x);
  // N x 1000(num_classes)

  return {x, aux1, aux2};
}
} // namespace models
} // namespace vision
#ifndef GOOGLENET_H
#define GOOGLENET_H
#include <torch/torch.h>
namespace vision {
namespace models {
namespace _googlenetimpl {
// Conv2d -> BatchNorm pair; forward applies an in-place ReLU after BN.
struct BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm bn{nullptr};

  BasicConv2dImpl(torch::nn::Conv2dOptions options);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

// GoogLeNet inception block: four parallel branches (1x1, 3x3, 5x5 and a
// pooled projection) concatenated along the channel dimension.
struct InceptionImpl : torch::nn::Module {
  BasicConv2d branch1{nullptr};
  torch::nn::Sequential branch2, branch3, branch4;

  InceptionImpl(
      int64_t in_channels,
      int64_t ch1x1,
      int64_t ch3x3red,
      int64_t ch3x3,
      int64_t ch5x5red,
      int64_t ch5x5,
      int64_t pool_proj);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Inception);

// Auxiliary classifier head attached to intermediate feature maps
// during training.
struct InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv{nullptr};
  torch::nn::Linear fc1{nullptr}, fc2{nullptr};

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionAux);
} // namespace _googlenetimpl
// Result of GoogLeNetImpl::forward: the main logits plus both auxiliary
// outputs. aux1/aux2 are undefined tensors unless the model ran in
// training mode with aux_logits enabled.
struct GoogLeNetOutput {
  torch::Tensor output;
  torch::Tensor aux1;
  torch::Tensor aux2;
};

// GoogLeNet (Inception v1) classification network.
struct GoogLeNetImpl : torch::nn::Module {
  // aux_logits: build/use the two auxiliary classifiers.
  // transform_input: re-normalize inputs inside forward.
  bool aux_logits, transform_input;

  _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};

  _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr},
      inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr},
      inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr},
      inception5b{nullptr};

  _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr};

  torch::nn::Dropout dropout{nullptr};
  torch::nn::Linear fc{nullptr};

  GoogLeNetImpl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false,
      bool init_weights = true);

  void _initialize_weights();

  GoogLeNetOutput forward(torch::Tensor x);
};

TORCH_MODULE(GoogLeNet);
} // namespace models
} // namespace vision
#endif // GOOGLENET_H
#include "inception.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
namespace _inceptionimpl {
BasicConv2dImpl::BasicConv2dImpl(
    torch::nn::Conv2dOptions options,
    double std_dev) {
  // The convolution never carries a bias: BatchNorm directly follows it.
  options.with_bias(false);
  conv = torch::nn::Conv2d(options);

  auto bn_options =
      torch::nn::BatchNormOptions(options.output_channels()).eps(0.001);
  bn = torch::nn::BatchNorm(bn_options);

  register_module("conv", conv);
  register_module("bn", bn);

  // Plain normal init stands in for the truncated-normal init used by the
  // reference Python implementation; std_dev is configurable per call site.
  torch::nn::init::normal_(conv->weight, 0, std_dev);
  torch::nn::init::constant_(bn->weight, 1);
  torch::nn::init::constant_(bn->bias, 0);
}
torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) {
  // conv -> batchnorm -> in-place ReLU.
  return torch::relu_(bn->forward(conv->forward(x)));
}
// InceptionA: 1x1, factored 5x5 (1x1 -> 5x5), double-3x3 and pooling
// branches; pool_features sets the width of the pooling branch.
InceptionAImpl::InceptionAImpl(int64_t in_channels, int64_t pool_features)
    : branch1x1(Options(in_channels, 64, 1)),
      branch5x5_1(Options(in_channels, 48, 1)),
      branch5x5_2(Options(48, 64, 5).padding(2)),
      branch3x3dbl_1(Options(in_channels, 64, 1)),
      branch3x3dbl_2(Options(64, 96, 3).padding(1)),
      branch3x3dbl_3(Options(96, 96, 3).padding(1)),
      branch_pool(Options(in_channels, pool_features, 1)) {
  // Names must match the Python model so converted weights load correctly.
  register_module("branch1x1", branch1x1);
  register_module("branch5x5_1", branch5x5_1);
  register_module("branch5x5_2", branch5x5_2);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3", branch3x3dbl_3);
  register_module("branch_pool", branch_pool);
}
torch::Tensor InceptionAImpl::forward(torch::Tensor x) {
  // Four parallel branches, concatenated along the channel dimension.
  auto b1 = branch1x1->forward(x);
  auto b2 = branch5x5_2->forward(branch5x5_1->forward(x));
  auto b3 = branch3x3dbl_3->forward(
      branch3x3dbl_2->forward(branch3x3dbl_1->forward(x)));
  auto b4 = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b2, b3, b4}, 1);
}
// InceptionB: grid-size-reduction block (stride-2 3x3 and double-3x3
// branches plus a stride-2 max pool in forward).
InceptionBImpl::InceptionBImpl(int64_t in_channels)
    : branch3x3(Options(in_channels, 384, 3).stride(2)),
      branch3x3dbl_1(Options(in_channels, 64, 1)),
      branch3x3dbl_2(Options(64, 96, 3).padding(1)),
      branch3x3dbl_3(Options(96, 96, 3).stride(2)) {
  // Names must match the Python model so converted weights load correctly.
  register_module("branch3x3", branch3x3);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3", branch3x3dbl_3);
}
torch::Tensor InceptionBImpl::forward(torch::Tensor x) {
  auto b1 = branch3x3->forward(x);
  auto b2 = branch3x3dbl_3->forward(
      branch3x3dbl_2->forward(branch3x3dbl_1->forward(x)));
  // Parallel stride-2 max pool keeps the input channels at half resolution.
  auto b3 = torch::max_pool2d(x, 3, 2);
  return torch::cat({b1, b2, b3}, 1);
}
// InceptionC: factorized 7x7 convolutions (1x7 and 7x1 pairs);
// channels_7x7 sets the intermediate width of the 7x7 branches.
InceptionCImpl::InceptionCImpl(int64_t in_channels, int64_t channels_7x7) {
  branch1x1 = BasicConv2d(Options(in_channels, 192, 1));

  auto c7 = channels_7x7;
  branch7x7_1 = BasicConv2d(Options(in_channels, c7, 1));
  branch7x7_2 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3}));
  branch7x7_3 = BasicConv2d(Options(c7, 192, {7, 1}).padding({3, 0}));

  branch7x7dbl_1 = BasicConv2d(Options(in_channels, c7, 1));
  branch7x7dbl_2 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0}));
  branch7x7dbl_3 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3}));
  branch7x7dbl_4 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0}));
  branch7x7dbl_5 = BasicConv2d(Options(c7, 192, {1, 7}).padding({0, 3}));

  branch_pool = BasicConv2d(Options(in_channels, 192, 1));

  // Names must match the Python model so converted weights load correctly.
  register_module("branch1x1", branch1x1);
  register_module("branch7x7_1", branch7x7_1);
  register_module("branch7x7_2", branch7x7_2);
  register_module("branch7x7_3", branch7x7_3);
  register_module("branch7x7dbl_1", branch7x7dbl_1);
  register_module("branch7x7dbl_2", branch7x7dbl_2);
  register_module("branch7x7dbl_3", branch7x7dbl_3);
  register_module("branch7x7dbl_4", branch7x7dbl_4);
  register_module("branch7x7dbl_5", branch7x7dbl_5);
  register_module("branch_pool", branch_pool);
}
torch::Tensor InceptionCImpl::forward(torch::Tensor x) {
  auto b1 = branch1x1->forward(x);

  // Single factorized 7x7: 1x1 -> 1x7 -> 7x1.
  auto b2 =
      branch7x7_3->forward(branch7x7_2->forward(branch7x7_1->forward(x)));

  // Double factorized 7x7.
  auto b3 = branch7x7dbl_1->forward(x);
  b3 = branch7x7dbl_2->forward(b3);
  b3 = branch7x7dbl_3->forward(b3);
  b3 = branch7x7dbl_4->forward(b3);
  b3 = branch7x7dbl_5->forward(b3);

  auto b4 = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b2, b3, b4}, 1);
}
// InceptionD: grid-size-reduction block with a stride-2 3x3 branch and a
// factorized 7x7 branch ending in a stride-2 3x3.
InceptionDImpl::InceptionDImpl(int64_t in_channels)
    : branch3x3_1(Options(in_channels, 192, 1)),
      branch3x3_2(Options(192, 320, 3).stride(2)),
      branch7x7x3_1(Options(in_channels, 192, 1)),
      branch7x7x3_2(Options(192, 192, {1, 7}).padding({0, 3})),
      branch7x7x3_3(Options(192, 192, {7, 1}).padding({3, 0})),
      branch7x7x3_4(Options(192, 192, 3).stride(2)) {
  // Names must match the Python model so converted weights load correctly.
  register_module("branch3x3_1", branch3x3_1);
  register_module("branch3x3_2", branch3x3_2);
  register_module("branch7x7x3_1", branch7x7x3_1);
  register_module("branch7x7x3_2", branch7x7x3_2);
  register_module("branch7x7x3_3", branch7x7x3_3);
  register_module("branch7x7x3_4", branch7x7x3_4);
}
torch::Tensor InceptionDImpl::forward(torch::Tensor x) {
  auto b1 = branch3x3_2->forward(branch3x3_1->forward(x));
  auto b2 = branch7x7x3_4->forward(branch7x7x3_3->forward(
      branch7x7x3_2->forward(branch7x7x3_1->forward(x))));
  // Parallel stride-2 max pool carries the input channels through.
  auto b3 = torch::max_pool2d(x, 3, 2);
  return torch::cat({b1, b2, b3}, 1);
}
// InceptionE: widest block; the 3x3 branches split into parallel 1x3/3x1
// halves whose outputs are concatenated in forward.
InceptionEImpl::InceptionEImpl(int64_t in_channels)
    : branch1x1(Options(in_channels, 320, 1)),
      branch3x3_1(Options(in_channels, 384, 1)),
      branch3x3_2a(Options(384, 384, {1, 3}).padding({0, 1})),
      branch3x3_2b(Options(384, 384, {3, 1}).padding({1, 0})),
      branch3x3dbl_1(Options(in_channels, 448, 1)),
      branch3x3dbl_2(Options(448, 384, 3).padding(1)),
      branch3x3dbl_3a(Options(384, 384, {1, 3}).padding({0, 1})),
      branch3x3dbl_3b(Options(384, 384, {3, 1}).padding({1, 0})),
      branch_pool(Options(in_channels, 192, 1)) {
  // Names must match the Python model so converted weights load correctly.
  register_module("branch1x1", branch1x1);
  register_module("branch3x3_1", branch3x3_1);
  register_module("branch3x3_2a", branch3x3_2a);
  register_module("branch3x3_2b", branch3x3_2b);
  register_module("branch3x3dbl_1", branch3x3dbl_1);
  register_module("branch3x3dbl_2", branch3x3dbl_2);
  register_module("branch3x3dbl_3a", branch3x3dbl_3a);
  register_module("branch3x3dbl_3b", branch3x3dbl_3b);
  register_module("branch_pool", branch_pool);
}
torch::Tensor InceptionEImpl::forward(torch::Tensor x) {
  auto b1 = branch1x1->forward(x);

  // 3x3 branch fans out into 1x3 and 3x1 halves that are concatenated.
  auto stem = branch3x3_1->forward(x);
  auto b2 = torch::cat(
      {branch3x3_2a->forward(stem), branch3x3_2b->forward(stem)}, 1);

  // Double-3x3 branch fans out the same way after an extra 3x3.
  auto dbl = branch3x3dbl_2->forward(branch3x3dbl_1->forward(x));
  auto b3 = torch::cat(
      {branch3x3dbl_3a->forward(dbl), branch3x3dbl_3b->forward(dbl)}, 1);

  auto b4 = branch_pool->forward(torch::avg_pool2d(x, 3, 1, 1));
  return torch::cat({b1, b2, b3, b4}, 1);
}
// Auxiliary classifier for InceptionV3: two convolutions followed by a
// linear layer; conv1 and fc use smaller init std-devs (0.01 / 0.001).
InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes)
    : conv0(BasicConv2d(Options(in_channels, 128, 1))),
      conv1(BasicConv2d(Options(128, 768, 5), 0.01)),
      fc(768, num_classes) {
  torch::nn::init::normal_(
      fc->weight,
      0,
      0.001); // Note: used instead of truncated normal initialization

  register_module("conv0", conv0);
  register_module("conv1", conv1);
  register_module("fc", fc);
}
torch::Tensor InceptionAuxImpl::forward(torch::Tensor x) {
  // Input: N x 768 x 17 x 17.
  x = torch::avg_pool2d(x, 5, 3); // N x 768 x 5 x 5
  x = conv0->forward(x); // N x 128 x 5 x 5
  x = conv1->forward(x); // N x 768 x 1 x 1
  x = torch::adaptive_avg_pool2d(x, {1, 1}); // N x 768 x 1 x 1
  x = x.view({x.size(0), -1}); // N x 768
  // N x num_classes
  return fc->forward(x);
}
} // namespace _inceptionimpl
// Builds the full InceptionV3 network. Module names mirror the Python
// implementation so converted pretrained weights load by name.
InceptionV3Impl::InceptionV3Impl(
    int64_t num_classes,
    bool aux_logits,
    bool transform_input)
    : aux_logits(aux_logits), transform_input(transform_input) {
  // Stem convolutions.
  Conv2d_1a_3x3 = _inceptionimpl::BasicConv2d(Options(3, 32, 3).stride(2));
  Conv2d_2a_3x3 = _inceptionimpl::BasicConv2d(Options(32, 32, 3));
  Conv2d_2b_3x3 = _inceptionimpl::BasicConv2d(Options(32, 64, 3).padding(1));
  Conv2d_3b_1x1 = _inceptionimpl::BasicConv2d(Options(64, 80, 1));
  Conv2d_4a_3x3 = _inceptionimpl::BasicConv2d(Options(80, 192, 3));
  // Inception stages; constructor arguments are (in_channels, ...) per block.
  Mixed_5b = _inceptionimpl::InceptionA(192, 32);
  Mixed_5c = _inceptionimpl::InceptionA(256, 64);
  Mixed_5d = _inceptionimpl::InceptionA(288, 64);
  Mixed_6a = _inceptionimpl::InceptionB(288);
  Mixed_6b = _inceptionimpl::InceptionC(768, 128);
  Mixed_6c = _inceptionimpl::InceptionC(768, 160);
  Mixed_6d = _inceptionimpl::InceptionC(768, 160);
  Mixed_6e = _inceptionimpl::InceptionC(768, 192);
  // The auxiliary head only exists when requested; it stays a nullptr
  // holder otherwise.
  if (aux_logits)
    AuxLogits = _inceptionimpl::InceptionAux(768, num_classes);
  Mixed_7a = _inceptionimpl::InceptionD(768);
  Mixed_7b = _inceptionimpl::InceptionE(1280);
  Mixed_7c = _inceptionimpl::InceptionE(2048);
  fc = torch::nn::Linear(2048, num_classes);
  torch::nn::init::normal_(
      fc->weight,
      0,
      0.1); // Note: used instead of truncated normal initialization
  register_module("Conv2d_1a_3x3", Conv2d_1a_3x3);
  register_module("Conv2d_2a_3x3", Conv2d_2a_3x3);
  register_module("Conv2d_2b_3x3", Conv2d_2b_3x3);
  register_module("Conv2d_3b_1x1", Conv2d_3b_1x1);
  register_module("Conv2d_4a_3x3", Conv2d_4a_3x3);
  register_module("Mixed_5b", Mixed_5b);
  register_module("Mixed_5c", Mixed_5c);
  register_module("Mixed_5d", Mixed_5d);
  register_module("Mixed_6a", Mixed_6a);
  register_module("Mixed_6b", Mixed_6b);
  register_module("Mixed_6c", Mixed_6c);
  register_module("Mixed_6d", Mixed_6d);
  register_module("Mixed_6e", Mixed_6e);
  // Only register the aux head when it was actually constructed.
  if (!AuxLogits.is_empty())
    register_module("AuxLogits", AuxLogits);
  register_module("Mixed_7a", Mixed_7a);
  register_module("Mixed_7b", Mixed_7b);
  register_module("Mixed_7c", Mixed_7c);
  register_module("fc", fc);
}
// Runs the InceptionV3 pipeline. Returns the main logits and, when
// training with aux_logits, the auxiliary logits (undefined tensor
// otherwise). The shape comments trace a N x 3 x 299 x 299 input.
InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) {
  // Optionally re-normalize from ImageNet statistics to the normalization
  // the pretrained weights expect.
  if (transform_input) {
    auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) +
        (0.485 - 0.5) / 0.5;
    auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) +
        (0.456 - 0.5) / 0.5;
    auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) +
        (0.406 - 0.5) / 0.5;
    x = torch::cat({x_ch0, x_ch1, x_ch2}, 1);
  }
  // N x 3 x 299 x 299
  x = Conv2d_1a_3x3->forward(x);
  // N x 32 x 149 x 149
  x = Conv2d_2a_3x3->forward(x);
  // N x 32 x 147 x 147
  x = Conv2d_2b_3x3->forward(x);
  // N x 64 x 147 x 147
  x = torch::max_pool2d(x, 3, 2);
  // N x 64 x 73 x 73
  x = Conv2d_3b_1x1->forward(x);
  // N x 80 x 73 x 73
  x = Conv2d_4a_3x3->forward(x);
  // N x 192 x 71 x 71
  x = torch::max_pool2d(x, 3, 2);
  // N x 192 x 35 x 35
  x = Mixed_5b->forward(x);
  // N x 256 x 35 x 35
  x = Mixed_5c->forward(x);
  // N x 288 x 35 x 35
  x = Mixed_5d->forward(x);
  // N x 288 x 35 x 35
  x = Mixed_6a->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6b->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6c->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6d->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_6e->forward(x);
  // N x 768 x 17 x 17
  // Auxiliary head taps the 17x17 feature map during training only.
  torch::Tensor aux;
  if (is_training() && aux_logits)
    aux = AuxLogits->forward(x);
  // N x 768 x 17 x 17
  x = Mixed_7a->forward(x);
  // N x 1280 x 8 x 8
  x = Mixed_7b->forward(x);
  // N x 2048 x 8 x 8
  x = Mixed_7c->forward(x);
  // N x 2048 x 8 x 8
  x = torch::adaptive_avg_pool2d(x, {1, 1});
  // N x 2048 x 1 x 1
  x = torch::dropout(x, 0.5, is_training());
  // N x 2048 x 1 x 1
  x = x.view({x.size(0), -1});
  // N x 2048
  x = fc->forward(x);
  // N x 1000 (num_classes)
  if (is_training() && aux_logits)
    return {x, aux};
  return {x, {}};
}
// end of InceptionV3 implementation (namespace _inceptionimpl closes above)
} // namespace models
} // namespace vision
#ifndef INCEPTION_H
#define INCEPTION_H
#include <torch/torch.h>
namespace vision {
namespace models {
namespace _inceptionimpl {
// Conv2d (bias-free) + BatchNorm; forward applies an in-place ReLU.
// std_dev is the std of the normal weight init (truncated-normal stand-in).
struct BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm bn{nullptr};

  BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

// 35x35 block: 1x1, factored 5x5, double-3x3 and pooling branches.
struct InceptionAImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;

  InceptionAImpl(int64_t in_channels, int64_t pool_features);

  torch::Tensor forward(torch::Tensor x);
};

// 35x35 -> 17x17 grid-size reduction block.
struct InceptionBImpl : torch::nn::Module {
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

  InceptionBImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

// 17x17 block with factorized 7x7 convolutions (1x7 / 7x1 pairs).
struct InceptionCImpl : torch::nn::Module {
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};

  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);

  torch::Tensor forward(torch::Tensor x);
};

// 17x17 -> 8x8 grid-size reduction block.
struct InceptionDImpl : torch::nn::Module {
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;

  InceptionDImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

// 8x8 block whose 3x3 branches fan out into parallel 1x3 / 3x1 halves.
struct InceptionEImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;

  InceptionEImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

// Auxiliary classifier head used during training.
struct InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);
TORCH_MODULE(InceptionAux);
} // namespace _inceptionimpl
// Forward result: the main logits plus the auxiliary logits. aux is an
// undefined tensor unless the model ran in training mode with aux_logits.
struct InceptionV3Output {
  torch::Tensor output;
  torch::Tensor aux;
};

// Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567>
struct InceptionV3Impl : torch::nn::Module {
  // aux_logits: build/use the auxiliary classifier.
  // transform_input: re-normalize inputs inside forward.
  bool aux_logits, transform_input;

  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};
  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};
  torch::nn::Linear fc{nullptr};
  _inceptionimpl::InceptionAux AuxLogits{nullptr};

  InceptionV3Impl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);
} // namespace models
} // namespace vision
#endif // INCEPTION_H
#include "mobilenet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// Conv2d -> BatchNorm -> ReLU6, the basic building block of MobileNetV2.
// Setting groups == in_planes makes the convolution depthwise.
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
  ConvBNReLUImpl(
      int64_t in_planes,
      int64_t out_planes,
      int64_t kernel_size = 3,
      int64_t stride = 1,
      int64_t groups = 1) {
    // "Same" padding for odd kernel sizes.
    auto padding = (kernel_size - 1) / 2;

    push_back(torch::nn::Conv2d(Options(in_planes, out_planes, kernel_size)
                                    .stride(stride)
                                    .padding(padding)
                                    .groups(groups)
                                    .with_bias(false)));
    push_back(torch::nn::BatchNorm(out_planes));
    // relu6_ wrapper is used because torch::clamp_ cannot be passed to
    // torch::nn::Functional directly (see modelsimpl.h).
    push_back(torch::nn::Functional(modelsimpl::relu6_));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(ConvBNReLU);
struct MobileNetInvertedResidualImpl : torch::nn::Module {
int64_t stride;
bool use_res_connect;
torch::nn::Sequential conv;
MobileNetInvertedResidualImpl(
int64_t input,
int64_t output,
int64_t stride,
double expand_ratio)
: stride(stride), use_res_connect(stride == 1 && input == output) {
auto double_compare = [](double a, double b) {
return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
};
assert(stride == 1 || stride == 2);
auto hidden_dim = int64_t(std::round(input * expand_ratio));
if (!double_compare(expand_ratio, 1))
conv->push_back(ConvBNReLU(input, hidden_dim, 1));
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
conv->push_back(torch::nn::Conv2d(
Options(hidden_dim, output, 1).stride(1).padding(0).with_bias(false)));
conv->push_back(torch::nn::BatchNorm(output));
register_module("conv", conv);
}
torch::Tensor forward(torch::Tensor x) {
if (use_res_connect)
return x + conv->forward(x);
return conv->forward(x);
}
};
TORCH_MODULE(MobileNetInvertedResidual);
// Builds MobileNetV2. width_mult scales every channel count; num_classes
// sets the size of the final linear classifier.
MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) {
  using Block = MobileNetInvertedResidual;
  int64_t input_channel = 32;
  int64_t last_channel = 1280;

  // Per stage: t = expansion factor, c = output channels,
  // n = number of blocks, s = stride of the stage's first block.
  std::vector<std::vector<int64_t>> inverted_residual_settings = {
      // t, c, n, s
      {1, 16, 1, 1},
      {6, 24, 2, 2},
      {6, 32, 3, 2},
      {6, 64, 4, 2},
      {6, 96, 3, 1},
      {6, 160, 3, 2},
      {6, 320, 1, 1},
  };

  // Scale channel counts by the width multiplier; the head never shrinks
  // below 1280 channels.
  input_channel = int64_t(input_channel * width_mult);
  this->last_channel = int64_t(last_channel * std::max(1.0, width_mult));

  features->push_back(ConvBNReLU(3, input_channel, 3, 2));

  // Iterate by const reference — the rows are read-only, and iterating by
  // value would copy each inner vector on every iteration.
  for (const auto& setting : inverted_residual_settings) {
    auto output_channel = int64_t(setting[1] * width_mult);
    for (int64_t i = 0; i < setting[2]; ++i) {
      // Only the first block of a stage downsamples.
      auto stride = i == 0 ? setting[3] : 1;
      features->push_back(
          Block(input_channel, output_channel, stride, setting[0]));
      input_channel = output_channel;
    }
  }
  features->push_back(ConvBNReLU(input_channel, this->last_channel, 1));

  classifier->push_back(torch::nn::Dropout(0.2));
  classifier->push_back(torch::nn::Linear(this->last_channel, num_classes));

  register_module("features", features);
  register_module("classifier", classifier);

  // Weight initialization mirroring the reference implementation.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          M->weight, 0, torch::nn::init::FanMode::FanOut);
      if (M->options.with_bias())
        torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(M->weight, 0, 0.01);
      torch::nn::init::zeros_(M->bias);
    }
  }
}
torch::Tensor MobileNetV2Impl::forward(at::Tensor x) {
  // Feature extractor, global average pool over H and W, then classifier.
  auto feats = features->forward(x);
  auto pooled = feats.mean({2, 3});
  return classifier->forward(pooled);
}
} // namespace models
} // namespace vision
#ifndef MOBILENET_H
#define MOBILENET_H
#include <torch/torch.h>
namespace vision {
namespace models {
// MobileNetV2 classification network.
struct MobileNetV2Impl : torch::nn::Module {
  // Channel count of the last feature layer (1280 scaled by width_mult,
  // never below 1280).
  int64_t last_channel;
  torch::nn::Sequential features, classifier;

  MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(MobileNetV2);
} // namespace models
} // namespace vision
#endif // MOBILENET_H
#ifndef MODELS_H
#define MODELS_H
#include "alexnet.h"
#include "densenet.h"
#include "googlenet.h"
#include "inception.h"
#include "mobilenet.h"
#include "resnet.h"
#include "shufflenetv2.h"
#include "squeezenet.h"
#include "vgg.h"
#endif // MODELS_H
#ifndef MODELSIMPL_H
#define MODELSIMPL_H
#include <torch/torch.h>
namespace vision {
namespace models {
namespace modelsimpl {
// TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in
// torch::nn::Fuctional don't work. so keeping these for now
// In-place ReLU wrapper usable with torch::nn::Functional.
// Returns the tensor handle by value: the previous version returned a
// Tensor& bound to the by-value parameter `x`, which dangles as soon as
// the function returns. Returning the handle by value is safe and remains
// compatible with torch::nn::Functional's Tensor(Tensor) signature.
inline torch::Tensor relu_(torch::Tensor x) {
  return torch::relu_(x);
}
// In-place ReLU6: clamps x to [0, 6] and returns the tensor handle.
inline torch::Tensor relu6_(torch::Tensor x) {
  return torch::clamp_(x, 0, 6);
}

// Thin wrapper over torch::adaptive_avg_pool2d (see TODO above).
inline torch::Tensor adaptive_avg_pool2d(
    torch::Tensor x,
    torch::ExpandingArray<2> output_size) {
  return torch::adaptive_avg_pool2d(x, output_size);
}

// Thin wrapper over torch::max_pool2d with explicit kernel and stride.
inline torch::Tensor max_pool2d(
    torch::Tensor x,
    torch::ExpandingArray<2> kernel_size,
    torch::ExpandingArray<2> stride) {
  return torch::max_pool2d(x, kernel_size, stride);
}
// Approximate floating-point equality: true when |a - b| is below machine
// epsilon. Note: epsilon is not scaled by the magnitude of the operands,
// so this is only suitable for values near 1 (as used for expand_ratio).
// Removed a redundant double(...) cast and a stray trailing semicolon.
inline bool double_compare(double a, double b) {
  return std::abs(a - b) < std::numeric_limits<double>::epsilon();
}
} // namespace modelsimpl
} // namespace models
} // namespace vision
#endif // MODELSIMPL_H
#include "resnet.h"
namespace vision {
namespace models {
namespace _resnetimpl {
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride,
    int64_t groups) {
  // 3x3 convolution with 1-pixel padding, configurable stride/groups,
  // and no bias (BatchNorm always follows in this file).
  auto options = torch::nn::Conv2dOptions(in, out, 3)
                     .padding(1)
                     .stride(stride)
                     .groups(groups)
                     .with_bias(false);
  return torch::nn::Conv2d(options);
}
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) {
  // 1x1 convolution, no padding, no bias.
  return torch::nn::Conv2d(
      torch::nn::Conv2dOptions(in, out, 1).stride(stride).with_bias(false));
}
// Channel expansion factor of the final convolution in each block type.
int BasicBlock::expansion = 1;
int Bottleneck::expansion = 4;
// Two 3x3 convolutions with BatchNorm and an identity/downsample shortcut.
// Grouped and widened variants are unsupported for BasicBlock (only
// Bottleneck implements them), hence the hard check below.
BasicBlock::BasicBlock(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    torch::nn::Sequential downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  // Use `||` instead of the alternative token `or` for consistency with
  // the rest of the codebase.
  if (groups != 1 || base_width != 64) {
    std::cerr << "BasicBlock only supports groups=1 and base_width=64"
              << std::endl;
    assert(false);
  }
  // Both conv1 and downsample layers downsample the input when stride != 1
  conv1 = conv3x3(inplanes, planes, stride);
  conv2 = conv3x3(planes, planes);
  bn1 = torch::nn::BatchNorm(planes);
  bn2 = torch::nn::BatchNorm(planes);
  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("bn1", bn1);
  register_module("bn2", bn2);
  // The shortcut is optional; only register it when one was supplied.
  if (!downsample.is_empty())
    register_module("downsample", this->downsample);
}
// 1x1 reduce -> 3x3 (stride/groups) -> 1x1 expand bottleneck with
// BatchNorm after each convolution and an optional downsample shortcut.
// groups/base_width implement the ResNeXt widening scheme.
Bottleneck::Bottleneck(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    torch::nn::Sequential downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  // Inner width scales with base_width and groups (64 is the reference).
  auto width = int64_t(planes * (base_width / 64.)) * groups;

  // Both conv2 and downsample layers downsample the input when stride != 1
  conv1 = conv1x1(inplanes, width);
  conv2 = conv3x3(width, width, stride, groups);
  conv3 = conv1x1(width, planes * expansion);

  bn1 = torch::nn::BatchNorm(width);
  bn2 = torch::nn::BatchNorm(width);
  bn3 = torch::nn::BatchNorm(planes * expansion);

  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("conv3", conv3);
  register_module("bn1", bn1);
  register_module("bn2", bn2);
  register_module("bn3", bn3);

  // The shortcut is optional; only register it when one was supplied.
  if (!downsample.is_empty())
    register_module("downsample", this->downsample);
}
torch::Tensor Bottleneck::forward(torch::Tensor X) {
  // Main path: reduce -> spatial conv -> expand, BN after each conv,
  // in-place ReLU after the first two.
  auto out = bn1->forward(conv1->forward(X)).relu_();
  out = bn2->forward(conv2->forward(out)).relu_();
  out = bn3->forward(conv3->forward(out));

  // Shortcut: identity, or the downsample path when one exists.
  auto identity = downsample.is_empty() ? X : downsample->forward(X);
  out += identity;
  return out.relu_();
}
torch::Tensor BasicBlock::forward(torch::Tensor x) {
  // Main path: two 3x3 convs with BN; ReLU only after the first.
  auto out = bn1->forward(conv1->forward(x)).relu_();
  out = bn2->forward(conv2->forward(out));

  // Shortcut: identity, or the downsample path when one exists.
  auto shortcut = downsample.is_empty() ? x : downsample->forward(x);
  out += shortcut;
  return out.relu_();
}
} // namespace _resnetimpl
// Stock depth configurations ({blocks per stage}); the ResNeXt variants
// additionally pass groups and width-per-group.
ResNet18Impl::ResNet18Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({2, 2, 2, 2}, num_classes, zero_init_residual) {}

ResNet34Impl::ResNet34Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

ResNet50Impl::ResNet50Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

ResNet101Impl::ResNet101Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual) {}

ResNet152Impl::ResNet152Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 8, 36, 3}, num_classes, zero_init_residual) {}

// ResNeXt-50 32x4d: 32 groups, 4 channels per group.
ResNext50_32x4dImpl::ResNext50_32x4dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 32, 4) {}

// ResNeXt-101 32x8d: 32 groups, 8 channels per group.
ResNext101_32x8dImpl::ResNext101_32x8dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}
} // namespace models
} // namespace vision
#ifndef RESNET_H
#define RESNET_H
#include <torch/torch.h>
namespace vision {
namespace models {
template <typename Block>
struct ResNetImpl;
namespace _resnetimpl {
// 3x3 convolution with padding
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride = 1,
    int64_t groups = 1);

// 1x1 convolution
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1);

// Two-conv residual block used by ResNet-18/34; expansion == 1.
struct BasicBlock : torch::nn::Module {
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr};

  // Channel expansion of the final conv (1 for BasicBlock).
  static int expansion;

  BasicBlock(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      torch::nn::Sequential downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor x);
};

// Three-conv bottleneck block used by ResNet-50+ and the ResNeXt
// variants; expansion == 4.
struct Bottleneck : torch::nn::Module {
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr};

  // Channel expansion of the final conv (4 for Bottleneck).
  static int expansion;

  Bottleneck(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      torch::nn::Sequential downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor X);
};
} // namespace _resnetimpl
// Generic ResNet parameterized by its residual block type
// (_resnetimpl::BasicBlock or _resnetimpl::Bottleneck).
template <typename Block>
struct ResNetImpl : torch::nn::Module {
  // groups/base_width implement ResNeXt; inplanes tracks the running
  // channel count while the layers are built.
  int64_t groups, base_width, inplanes;
  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm bn1;
  torch::nn::Linear fc;
  torch::nn::Sequential layer1, layer2, layer3, layer4;

  // Builds one stage of `blocks` residual blocks; the first block may
  // downsample via `stride`.
  torch::nn::Sequential _make_layer(
      int64_t planes,
      int64_t blocks,
      int64_t stride = 1);

  ResNetImpl(
      const std::vector<int>& layers,
      int64_t num_classes = 1000,
      bool zero_init_residual = false,
      int64_t groups = 1,
      int64_t width_per_group = 64);

  torch::Tensor forward(torch::Tensor X);
};
// Builds a residual stage. The first block receives `stride` and, when
// the shape changes, a 1x1-conv + BN downsample path; the remaining
// blocks are stride-1 with identity shortcuts. Mutates `inplanes` so the
// next stage starts from this stage's output width.
template <typename Block>
torch::nn::Sequential ResNetImpl<Block>::_make_layer(
    int64_t planes,
    int64_t blocks,
    int64_t stride) {
  torch::nn::Sequential downsample = nullptr;
  // A projection shortcut is needed when spatial size or channels change.
  if (stride != 1 || inplanes != planes * Block::expansion) {
    downsample = torch::nn::Sequential(
        _resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride),
        torch::nn::BatchNorm(planes * Block::expansion));
  }

  torch::nn::Sequential layers;
  layers->push_back(
      Block(inplanes, planes, stride, downsample, groups, base_width));

  // All subsequent blocks consume the expanded channel count.
  inplanes = planes * Block::expansion;
  for (int i = 1; i < blocks; ++i)
    layers->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width));

  return layers;
}
// Builds the full network: 7x7 stem, four residual stages sized by
// `layers`, and a linear classifier. Note: _make_layer mutates inplanes,
// so the layer1..layer4 member-initializer order is significant.
template <typename Block>
ResNetImpl<Block>::ResNetImpl(
    const std::vector<int>& layers,
    int64_t num_classes,
    bool zero_init_residual,
    int64_t groups,
    int64_t width_per_group)
    : groups(groups),
      base_width(width_per_group),
      inplanes(64),
      conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).with_bias(
          false)),
      bn1(64),
      layer1(_make_layer(64, layers[0])),
      layer2(_make_layer(128, layers[1], 2)),
      layer3(_make_layer(256, layers[2], 2)),
      layer4(_make_layer(512, layers[3], 2)),
      fc(512 * Block::expansion, num_classes) {
  register_module("conv1", conv1);
  register_module("bn1", bn1);
  register_module("fc", fc);

  register_module("layer1", layer1);
  register_module("layer2", layer2);
  register_module("layer3", layer3);
  register_module("layer4", layer4);

  // Kaiming init for convolutions, constant init for BatchNorm.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(
          M->weight,
          /*a=*/0,
          torch::nn::init::FanMode::FanOut,
          torch::nn::init::Nonlinearity::ReLU);
    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::constant_(M->weight, 1);
      torch::nn::init::constant_(M->bias, 0);
    }
  }

  // Zero-initialize the last BN in each residual branch, so that the residual
  // branch starts with zeros, and each residual block behaves like an
  // identity. This improves the model by 0.2~0.3% according to
  // https://arxiv.org/abs/1706.02677
  if (zero_init_residual)
    for (auto& module : modules(/*include_self=*/false)) {
      if (auto* M = dynamic_cast<_resnetimpl::Bottleneck*>(module.get()))
        torch::nn::init::constant_(M->bn3->weight, 0);
      else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get()))
        torch::nn::init::constant_(M->bn2->weight, 0);
    }
}
// Runs the standard ResNet pipeline: stem -> 4 residual stages -> global
// average pool -> flatten -> linear classifier. Returns class logits.
template <typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
  // Stem: conv -> BN -> in-place ReLU -> 3x3/2 max-pool.
  auto out = bn1->forward(conv1->forward(x)).relu_();
  out = torch::max_pool2d(out, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);

  // Residual stages.
  for (auto* stage : {&layer1, &layer2, &layer3, &layer4})
    out = (*stage)->forward(out);

  // Global pooling, flatten to (N, C), classify.
  out = torch::adaptive_avg_pool2d(out, {1, 1});
  out = out.reshape({out.size(0), -1});
  return fc->forward(out);
}
// Concrete ResNet variants. Each constructor (defined elsewhere) supplies
// the per-stage block counts; 18/34 use BasicBlock, 50/101/152 Bottleneck.
struct ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

struct ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

struct ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

struct ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

struct ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

// ResNeXt variants: Bottleneck with grouped convolutions (the groups /
// width-per-group values are passed by the constructors, defined elsewhere).
struct ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNext50_32x4dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

struct ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNext101_32x8dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};
// Generic ModuleHolder wrapper so ResNet<Block> follows the usual libtorch
// Module / ModuleHolder ownership pattern.
template <typename Block>
struct ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
  using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};

// Public holder aliases (e.g. ResNet18 wraps ResNet18Impl).
TORCH_MODULE(ResNet18);
TORCH_MODULE(ResNet34);
TORCH_MODULE(ResNet50);
TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
} // namespace models
} // namespace vision
#endif // RESNET_H
#include "shufflenetv2.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;
// Interleaves channels across `groups`: reshape [N, C, H, W] to
// [N, g, C/g, H, W], swap the group and channel axes, and flatten back.
// Lets information flow between branches in ShuffleNetV2.
torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) {
  const auto sizes = x.sizes();
  const int64_t batch = sizes[0];
  const int64_t channels = sizes[1];
  const int64_t height = sizes[2];
  const int64_t width = sizes[3];

  auto shuffled =
      x.view({batch, groups, channels / groups, height, width});
  // contiguous() is required before the final view after the transpose.
  shuffled = shuffled.transpose(1, 2).contiguous();
  return shuffled.view({batch, -1, height, width});
}
// Builds a bias-free 1x1 pointwise convolution (BN supplies the bias).
torch::nn::Conv2d conv11(int64_t input, int64_t output) {
  return torch::nn::Conv2d(
      Options(input, output, 1).stride(1).padding(0).with_bias(false));
}
// Builds a bias-free 3x3 depthwise convolution (groups == input channels).
torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) {
  return torch::nn::Conv2d(Options(input, output, 3)
                               .stride(stride)
                               .padding(1)
                               .with_bias(false)
                               .groups(input));
}
// One ShuffleNetV2 "inverted residual" building block.
//
// stride == 1: the input is split channel-wise in two; one half passes
// through unchanged, the other goes through branch2, and the halves are
// concatenated. stride > 1: both branches process the full input and their
// outputs are concatenated (spatial downsampling). A channel shuffle then
// mixes information between the two halves.
struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
  int64_t stride;
  torch::nn::Sequential branch1{nullptr}, branch2{nullptr};

  ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride)
      : stride(stride) {
    if (stride < 1 || stride > 3) {
      // Fixed: the message previously contained a stray trailing apostrophe
      // ("illegal stride value'").
      std::cerr << "illegal stride value" << std::endl;
      assert(false);
    }

    auto branch_features = oup / 2;
    // With stride == 1 the input is chunked in two, so the input channel
    // count must be exactly twice the per-branch width.
    assert(stride != 1 || inp == branch_features << 1);

    if (stride > 1) {
      // Downsampling branch: depthwise 3x3/stride -> BN -> 1x1 -> BN -> ReLU.
      branch1 = torch::nn::Sequential(
          conv33(inp, inp, stride),
          torch::nn::BatchNorm(inp),
          conv11(inp, branch_features),
          torch::nn::BatchNorm(branch_features),
          torch::nn::Functional(modelsimpl::relu_));
    }

    // Main branch: 1x1 -> BN -> ReLU -> depthwise 3x3 -> BN -> 1x1 -> BN ->
    // ReLU. With stride == 1 it only sees the second half of the channels.
    branch2 = torch::nn::Sequential(
        conv11(stride > 1 ? inp : branch_features, branch_features),
        torch::nn::BatchNorm(branch_features),
        torch::nn::Functional(modelsimpl::relu_),
        conv33(branch_features, branch_features, stride),
        torch::nn::BatchNorm(branch_features),
        conv11(branch_features, branch_features),
        torch::nn::BatchNorm(branch_features),
        torch::nn::Functional(modelsimpl::relu_));

    // branch1 only exists for stride > 1.
    if (!branch1.is_empty())
      register_module("branch1", branch1);
    register_module("branch2", branch2);
  }

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor out;
    if (stride == 1) {
      auto chunks = x.chunk(2, /*dim=*/1);
      out = torch::cat({chunks[0], branch2->forward(chunks[1])}, 1);
    } else
      out = torch::cat({branch1->forward(x), branch2->forward(x)}, 1);

    // Shuffle across the two concatenated halves.
    out = channel_shuffle(out, 2);
    return out;
  }
};
TORCH_MODULE(ShuffleNetV2InvertedResidual);
// Builds the ShuffleNetV2 trunk.
//
// stage_repeats:      number of inverted-residual blocks in stages 2-4
//                     (must have exactly 3 entries).
// stage_out_channels: output widths for conv1, stages 2-4 and conv5
//                     (must have exactly 5 entries).
// num_classes:        size of the final classification layer.
ShuffleNetV2Impl::ShuffleNetV2Impl(
    const std::vector<int64_t>& stage_repeats,
    const std::vector<int64_t>& stage_out_channels,
    int64_t num_classes) {
  if (stage_repeats.size() != 3) {
    std::cerr << "expected stage_repeats as vector of 3 positive ints"
              << std::endl;
    assert(false);
  }
  if (stage_out_channels.size() != 5) {
    std::cerr << "expected stage_out_channels as vector of 5 positive ints"
              << std::endl;
    assert(false);
  }

  _stage_out_channels = stage_out_channels;
  int64_t input_channels = 3;
  auto output_channels = _stage_out_channels[0];

  // Stem: 3x3 stride-2 conv -> BN -> ReLU.
  conv1 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 3)
                            .stride(2)
                            .padding(1)
                            .with_bias(false)),
      torch::nn::BatchNorm(output_channels),
      torch::nn::Functional(modelsimpl::relu_));
  input_channels = output_channels;

  // Stages 2-4: one stride-2 (downsampling) block followed by repeats - 1
  // stride-1 blocks. The Sequential holders share their impl with the
  // members, so pushing into the local copies populates stage2/3/4 directly.
  std::vector<torch::nn::Sequential> stages = {stage2, stage3, stage4};
  for (size_t i = 0; i < stages.size(); ++i) {
    auto& seq = stages[i];
    auto repeats = stage_repeats[i];
    auto output_channels = _stage_out_channels[i + 1];

    seq->push_back(
        ShuffleNetV2InvertedResidual(input_channels, output_channels, 2));
    // Fixed: loop over a signed count. The old `size_t(repeats - 1)` bound
    // underflowed to a huge value when repeats == 0.
    for (int64_t j = 0; j < repeats - 1; ++j)
      seq->push_back(
          ShuffleNetV2InvertedResidual(output_channels, output_channels, 1));
    input_channels = output_channels;
  }

  // Final 1x1 conv before global pooling and the classifier.
  output_channels = _stage_out_channels.back();
  conv5 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 1)
                            .stride(1)
                            .padding(0)
                            .with_bias(false)),
      torch::nn::BatchNorm(output_channels),
      torch::nn::Functional(modelsimpl::relu_));

  fc = torch::nn::Linear(output_channels, num_classes);

  register_module("conv1", conv1);
  register_module("stage2", stage2);
  register_module("stage3", stage3);
  register_module("stage4", stage4);
  // Fixed: conv5 was registered under the name "conv2", which matches
  // neither the member name nor the Python reference model's "conv5" and
  // breaks parameter transfer by name.
  register_module("conv5", conv5);
  register_module("fc", fc);
}
// Runs the ShuffleNetV2 pipeline: stem -> max-pool -> three shuffle stages
// -> 1x1 conv -> global average pool -> classifier. Returns class logits.
torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) {
  auto out = conv1->forward(x);
  out = torch::max_pool2d(out, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);

  for (auto* stage : {&stage2, &stage3, &stage4})
    out = (*stage)->forward(out);

  out = conv5->forward(out);
  // Global average pool over the spatial dimensions, then classify.
  out = out.mean({2, 3});
  return fc->forward(out);
}
// Standard ShuffleNetV2 width multipliers (x0.5 .. x2.0). All variants use
// stage repeats {4, 8, 4}; only the channel widths differ.
ShuffleNetV2_x0_5Impl::ShuffleNetV2_x0_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 48, 96, 192, 1024}, num_classes) {}

ShuffleNetV2_x1_0Impl::ShuffleNetV2_x1_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 116, 232, 464, 1024}, num_classes) {}

ShuffleNetV2_x1_5Impl::ShuffleNetV2_x1_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 176, 352, 704, 1024}, num_classes) {}

ShuffleNetV2_x2_0Impl::ShuffleNetV2_x2_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl({4, 8, 4}, {24, 244, 488, 976, 2048}, num_classes) {}
} // namespace models
} // namespace vision
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment