#pragma once #include #include "../macros.h" namespace vision { namespace models { namespace _inceptionimpl { struct VISION_API BasicConv2dImpl : torch::nn::Module { torch::nn::Conv2d conv{nullptr}; torch::nn::BatchNorm2d bn{nullptr}; explicit BasicConv2dImpl( torch::nn::Conv2dOptions options, double std_dev = 0.1); torch::Tensor forward(torch::Tensor x); }; TORCH_MODULE(BasicConv2d); struct VISION_API InceptionAImpl : torch::nn::Module { BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3, branch_pool; InceptionAImpl(int64_t in_channels, int64_t pool_features); torch::Tensor forward(const torch::Tensor& x); }; struct VISION_API InceptionBImpl : torch::nn::Module { BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3; explicit InceptionBImpl(int64_t in_channels); torch::Tensor forward(const torch::Tensor& x); }; struct VISION_API InceptionCImpl : torch::nn::Module { BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr}, branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr}, branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr}, branch_pool{nullptr}; InceptionCImpl(int64_t in_channels, int64_t channels_7x7); torch::Tensor forward(const torch::Tensor& x); }; struct VISION_API InceptionDImpl : torch::nn::Module { BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2, branch7x7x3_3, branch7x7x3_4; explicit InceptionDImpl(int64_t in_channels); torch::Tensor forward(const torch::Tensor& x); }; struct VISION_API InceptionEImpl : torch::nn::Module { BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b, branch_pool; explicit InceptionEImpl(int64_t in_channels); torch::Tensor forward(const torch::Tensor& x); }; struct VISION_API InceptionAuxImpl : torch::nn::Module { BasicConv2d conv0; BasicConv2d conv1; torch::nn::Linear fc; InceptionAuxImpl(int64_t in_channels, int64_t num_classes); torch::Tensor forward(torch::Tensor x); }; TORCH_MODULE(InceptionA); TORCH_MODULE(InceptionB); TORCH_MODULE(InceptionC); TORCH_MODULE(InceptionD); TORCH_MODULE(InceptionE); TORCH_MODULE(InceptionAux); } // namespace _inceptionimpl struct VISION_API InceptionV3Output { torch::Tensor output; torch::Tensor aux; }; // Inception v3 model architecture from //"Rethinking the Inception Architecture for Computer Vision" // struct VISION_API InceptionV3Impl : torch::nn::Module { bool aux_logits, transform_input; _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr}, Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr}; _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr}, Mixed_5d{nullptr}; _inceptionimpl::InceptionB Mixed_6a{nullptr}; _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr}, Mixed_6d{nullptr}, Mixed_6e{nullptr}; _inceptionimpl::InceptionD Mixed_7a{nullptr}; _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr}; torch::nn::Linear fc{nullptr}; _inceptionimpl::InceptionAux AuxLogits{nullptr}; explicit InceptionV3Impl( int64_t num_classes = 1000, bool aux_logits = true, bool transform_input = false); InceptionV3Output forward(torch::Tensor x); }; TORCH_MODULE(InceptionV3); } // namespace models } // namespace vision