#include "densenet.h" #include "modelsimpl.h" namespace vision { namespace models { using Options = torch::nn::Conv2dOptions; struct _DenseLayerImpl : torch::nn::SequentialImpl { double drop_rate; _DenseLayerImpl( int64_t num_input_features, int64_t growth_rate, int64_t bn_size, double drop_rate) : drop_rate(drop_rate) { push_back("norm1", torch::nn::BatchNorm2d(num_input_features)); push_back("relu1", torch::nn::Functional(modelsimpl::relu_)); push_back( "conv1", torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1) .stride(1) .bias(false))); push_back("norm2", torch::nn::BatchNorm2d(bn_size * growth_rate)); push_back("relu2", torch::nn::Functional(modelsimpl::relu_)); push_back( "conv2", torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3) .stride(1) .padding(1) .bias(false))); } torch::Tensor forward(torch::Tensor x) { auto new_features = torch::nn::SequentialImpl::forward(x); if (drop_rate > 0) new_features = torch::dropout(new_features, drop_rate, this->is_training()); return torch::cat({x, new_features}, 1); } }; TORCH_MODULE(_DenseLayer); struct _DenseBlockImpl : torch::nn::SequentialImpl { _DenseBlockImpl( int64_t num_layers, int64_t num_input_features, int64_t bn_size, int64_t growth_rate, double drop_rate) { for (int64_t i = 0; i < num_layers; ++i) { auto layer = _DenseLayer( num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate); push_back("denselayer" + std::to_string(i + 1), layer); } } torch::Tensor forward(torch::Tensor x) { return torch::nn::SequentialImpl::forward(x); } }; TORCH_MODULE(_DenseBlock); struct _TransitionImpl : torch::nn::SequentialImpl { _TransitionImpl(int64_t num_input_features, int64_t num_output_features) { push_back("norm", torch::nn::BatchNorm2d(num_input_features)); push_back("relu ", torch::nn::Functional(modelsimpl::relu_)); push_back( "conv", torch::nn::Conv2d(Options(num_input_features, num_output_features, 1) .stride(1) .bias(false))); push_back("pool", torch::nn::Functional([](const torch::Tensor& input) { return torch::avg_pool2d(input, 2, 2, 0, false, true); })); } torch::Tensor forward(torch::Tensor x) { return torch::nn::SequentialImpl::forward(x); } }; TORCH_MODULE(_Transition); DenseNetImpl::DenseNetImpl( int64_t num_classes, int64_t growth_rate, const std::vector& block_config, int64_t num_init_features, int64_t bn_size, double drop_rate) { // First convolution features = torch::nn::Sequential(); features->push_back( "conv0", torch::nn::Conv2d( Options(3, num_init_features, 7).stride(2).padding(3).bias(false))); features->push_back("norm0", torch::nn::BatchNorm2d(num_init_features)); features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_)); features->push_back( "pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false)); // Each denseblock auto num_features = num_init_features; for (size_t i = 0; i < block_config.size(); ++i) { auto num_layers = block_config[i]; _DenseBlock block( num_layers, num_features, bn_size, growth_rate, drop_rate); features->push_back("denseblock" + std::to_string(i + 1), block); num_features = num_features + num_layers * growth_rate; if (i != block_config.size() - 1) { auto trans = _Transition(num_features, num_features / 2); features->push_back("transition" + std::to_string(i + 1), trans); num_features = num_features / 2; } } // Final batch norm features->push_back("norm5", torch::nn::BatchNorm2d(num_features)); // Linear layer classifier = torch::nn::Linear(num_features, num_classes); register_module("features", features); 
register_module("classifier", classifier); // Official init from torch repo. for (auto& module : modules(/*include_self=*/false)) { if (auto M = dynamic_cast(module.get())) torch::nn::init::kaiming_normal_(M->weight); else if (auto M = dynamic_cast(module.get())) { torch::nn::init::constant_(M->weight, 1); torch::nn::init::constant_(M->bias, 0); } else if (auto M = dynamic_cast(module.get())) torch::nn::init::constant_(M->bias, 0); } modelsimpl::deprecation_warning(); } torch::Tensor DenseNetImpl::forward(torch::Tensor x) { auto features = this->features->forward(x); auto out = torch::relu_(features); out = torch::adaptive_avg_pool2d(out, {1, 1}); out = out.view({features.size(0), -1}); out = this->classifier->forward(out); return out; } DenseNet121Impl::DenseNet121Impl( int64_t num_classes, int64_t growth_rate, const std::vector& block_config, int64_t num_init_features, int64_t bn_size, double drop_rate) : DenseNetImpl( num_classes, growth_rate, block_config, num_init_features, bn_size, drop_rate) {} DenseNet169Impl::DenseNet169Impl( int64_t num_classes, int64_t growth_rate, const std::vector& block_config, int64_t num_init_features, int64_t bn_size, double drop_rate) : DenseNetImpl( num_classes, growth_rate, block_config, num_init_features, bn_size, drop_rate) {} DenseNet201Impl::DenseNet201Impl( int64_t num_classes, int64_t growth_rate, const std::vector& block_config, int64_t num_init_features, int64_t bn_size, double drop_rate) : DenseNetImpl( num_classes, growth_rate, block_config, num_init_features, bn_size, drop_rate) {} DenseNet161Impl::DenseNet161Impl( int64_t num_classes, int64_t growth_rate, const std::vector& block_config, int64_t num_init_features, int64_t bn_size, double drop_rate) : DenseNetImpl( num_classes, growth_rate, block_config, num_init_features, bn_size, drop_rate) {} } // namespace models } // namespace vision