#include "mobilenet.h"

#include <algorithm>
#include <cmath>
#include <limits>

#include "modelsimpl.h"

namespace vision {
namespace models {

using Options = torch::nn::Conv2dOptions;

/// Rounds `value` to the nearest multiple of `divisor`, clamped from below by
/// `min_value` (which defaults to `divisor`). If the rounded result ends up
/// below 90% of `value`, one extra `divisor` is added.
///
/// @param value     Quantity to round (e.g. base channels * width multiplier).
/// @param divisor   Granularity of the result; must be positive.
/// @param min_value Lower bound for the result; defaults to `divisor`.
/// @return The rounded, clamped value.
/// @throws c10::Error (via TORCH_CHECK) if `divisor` is not positive.
int64_t make_divisible(
    double value,
    int64_t divisor,
    c10::optional<int64_t> min_value = {}) {
  // A zero divisor would divide by zero below; a negative one is meaningless.
  TORCH_CHECK(
      divisor > 0, "make_divisible: divisor must be positive, got ", divisor);
  if (!min_value.has_value())
    min_value = divisor;

  // Use a floating-point half-divisor so odd divisors also round to the
  // nearest multiple (integer `divisor / 2` truncates, e.g. 3 / 2 == 1).
  // For even divisors this is identical to the integer form.
  auto new_value = std::max(
      min_value.value(),
      (static_cast<int64_t>(value + divisor / 2.0) / divisor) * divisor);

  // Do not let rounding down lose more than 10% of the requested value.
  if (new_value < 0.9 * value)
    new_value += divisor;
  return new_value;
}

/// Conv2d (bias disabled) -> BatchNorm2d -> `modelsimpl::relu6_`, with
/// padding (kernel_size - 1) / 2.
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
  /// @param in_planes   Input channels of the convolution.
  /// @param out_planes  Output channels of the convolution / BatchNorm.
  /// @param kernel_size Square kernel size.
  /// @param stride      Convolution stride.
  /// @param groups      Convolution groups.
  ConvBNReLUImpl(
      int64_t in_planes,
      int64_t out_planes,
      int64_t kernel_size = 3,
      int64_t stride = 1,
      int64_t groups = 1) {
    auto padding = (kernel_size - 1) / 2;

    push_back(torch::nn::Conv2d(Options(in_planes, out_planes, kernel_size)
                                    .stride(stride)
                                    .padding(padding)
                                    .groups(groups)
                                    .bias(false)));
    push_back(torch::nn::BatchNorm2d(out_planes));
    push_back(torch::nn::Functional(modelsimpl::relu6_));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(ConvBNReLU);

/// MobileNetV2 inverted residual block:
///   1. an optional 1x1 expansion ConvBNReLU (skipped when expand_ratio == 1),
///   2. a 3x3 depthwise ConvBNReLU (groups == hidden_dim) with `stride`,
///   3. a 1x1 projection Conv2d (no bias) followed by BatchNorm2d.
/// The input is added to the output when stride == 1 and input == output.
struct MobileNetInvertedResidualImpl : torch::nn::Module {
  int64_t stride; // stride of the depthwise convolution (1 or 2)
  bool use_res_connect; // whether forward() adds the skip connection
  torch::nn::Sequential conv;

  /// @param input        Number of input channels.
  /// @param output       Number of output channels.
  /// @param stride       Depthwise stride; must be 1 or 2.
  /// @param expand_ratio Hidden width = round(input * expand_ratio).
  /// @throws c10::Error (via TORCH_CHECK) if stride is not 1 or 2.
  MobileNetInvertedResidualImpl(
      int64_t input,
      int64_t output,
      int64_t stride,
      double expand_ratio)
      : stride(stride), use_res_connect(stride == 1 && input == output) {
    auto double_compare = [](double a, double b) {
      return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
    };

    TORCH_CHECK(stride == 1 || stride == 2);
    auto hidden_dim = int64_t(std::round(input * expand_ratio));

    if (!double_compare(expand_ratio, 1))
      conv->push_back(ConvBNReLU(input, hidden_dim, 1));

    conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
    conv->push_back(torch::nn::Conv2d(
        Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
    conv->push_back(torch::nn::BatchNorm2d(output));

    register_module("conv", conv);
  }

  torch::Tensor forward(torch::Tensor x) {
    if (use_res_connect)
      return x + conv->forward(x);
    return conv->forward(x);
  }
};
TORCH_MODULE(MobileNetInvertedResidual); MobileNetV2Impl::MobileNetV2Impl( int64_t num_classes, double width_mult, std::vector> inverted_residual_settings, int64_t round_nearest) { using Block = MobileNetInvertedResidual; int64_t input_channel = 32; int64_t last_channel = 1280; if (inverted_residual_settings.empty()) inverted_residual_settings = { // t, c, n, s {1, 16, 1, 1}, {6, 24, 2, 2}, {6, 32, 3, 2}, {6, 64, 4, 2}, {6, 96, 3, 1}, {6, 160, 3, 2}, {6, 320, 1, 1}, }; TORCH_CHECK( inverted_residual_settings[0].size() == 4, "inverted_residual_settings should contain 4-element vectors"); input_channel = make_divisible(input_channel * width_mult, round_nearest); this->last_channel = make_divisible(last_channel * std::max(1.0, width_mult), round_nearest); features->push_back(ConvBNReLU(3, input_channel, 3, 2)); for (auto setting : inverted_residual_settings) { auto output_channel = make_divisible(setting[1] * width_mult, round_nearest); for (int64_t i = 0; i < setting[2]; ++i) { auto stride = i == 0 ? 
setting[3] : 1; features->push_back( Block(input_channel, output_channel, stride, setting[0])); input_channel = output_channel; } } features->push_back(ConvBNReLU(input_channel, this->last_channel, 1)); classifier->push_back(torch::nn::Dropout(0.2)); classifier->push_back(torch::nn::Linear(this->last_channel, num_classes)); register_module("features", features); register_module("classifier", classifier); for (auto& module : modules(/*include_self=*/false)) { if (auto M = dynamic_cast(module.get())) { torch::nn::init::kaiming_normal_( M->weight, 0, torch::kFanOut); if (M->options.bias()) torch::nn::init::zeros_(M->bias); } else if (auto M = dynamic_cast(module.get())) { torch::nn::init::ones_(M->weight); torch::nn::init::zeros_(M->bias); } else if (auto M = dynamic_cast(module.get())) { torch::nn::init::normal_(M->weight, 0, 0.01); torch::nn::init::zeros_(M->bias); } } } torch::Tensor MobileNetV2Impl::forward(at::Tensor x) { x = features->forward(x); x = x.mean({2, 3}); x = classifier->forward(x); return x; } } // namespace models } // namespace vision