Commit f6262182 authored by philipNoonan's avatar philipNoonan Committed by Francisco Massa
Browse files

Enabling exporting symbols on windows (#1035)

* Enabling exporting symbols on windows

Small fix to allow for the built library to be used in windows #728

* added macro to allow for exported symbols on windows

* added macro to allow for exported symbols on windows

* removed cmake command

* added dllimport using torchvision_EXPORTS preprocessor
parent c94a1585
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
#define ALEXNET_H #define ALEXNET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
// AlexNet model architecture from the // AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper. // "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
struct AlexNetImpl : torch::nn::Module { struct VISION_API AlexNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr}; torch::nn::Sequential features{nullptr}, classifier{nullptr};
AlexNetImpl(int64_t num_classes = 1000); AlexNetImpl(int64_t num_classes = 1000);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DENSENET_H #define DENSENET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
...@@ -18,7 +19,7 @@ namespace models { ...@@ -18,7 +19,7 @@ namespace models {
// bn_size (int) - multiplicative factor for number of bottle neck layers // bn_size (int) - multiplicative factor for number of bottle neck layers
// (i.e. bn_size * k features in the bottleneck layer) // (i.e. bn_size * k features in the bottleneck layer)
// drop_rate (float) - dropout rate after each dense layer // drop_rate (float) - dropout rate after each dense layer
struct DenseNetImpl : torch::nn::Module { struct VISION_API DenseNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}; torch::nn::Sequential features{nullptr};
torch::nn::Linear classifier{nullptr}; torch::nn::Linear classifier{nullptr};
...@@ -33,7 +34,7 @@ struct DenseNetImpl : torch::nn::Module { ...@@ -33,7 +34,7 @@ struct DenseNetImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct DenseNet121Impl : DenseNetImpl { struct VISION_API DenseNet121Impl : DenseNetImpl {
DenseNet121Impl( DenseNet121Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
...@@ -43,7 +44,7 @@ struct DenseNet121Impl : DenseNetImpl { ...@@ -43,7 +44,7 @@ struct DenseNet121Impl : DenseNetImpl {
double drop_rate = 0); double drop_rate = 0);
}; };
struct DenseNet169Impl : DenseNetImpl { struct VISION_API DenseNet169Impl : DenseNetImpl {
DenseNet169Impl( DenseNet169Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
...@@ -53,7 +54,7 @@ struct DenseNet169Impl : DenseNetImpl { ...@@ -53,7 +54,7 @@ struct DenseNet169Impl : DenseNetImpl {
double drop_rate = 0); double drop_rate = 0);
}; };
struct DenseNet201Impl : DenseNetImpl { struct VISION_API DenseNet201Impl : DenseNetImpl {
DenseNet201Impl( DenseNet201Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
...@@ -63,7 +64,7 @@ struct DenseNet201Impl : DenseNetImpl { ...@@ -63,7 +64,7 @@ struct DenseNet201Impl : DenseNetImpl {
double drop_rate = 0); double drop_rate = 0);
}; };
struct DenseNet161Impl : DenseNetImpl { struct VISION_API DenseNet161Impl : DenseNetImpl {
DenseNet161Impl( DenseNet161Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 48, int64_t growth_rate = 48,
......
#ifndef VISION_GENERAL_H
#define VISION_GENERAL_H
#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
#endif // VISION_GENERAL_H
\ No newline at end of file
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
#define GOOGLENET_H #define GOOGLENET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
namespace _googlenetimpl { namespace _googlenetimpl {
struct BasicConv2dImpl : torch::nn::Module { struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr}; torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm bn{nullptr}; torch::nn::BatchNorm bn{nullptr};
...@@ -18,7 +19,7 @@ struct BasicConv2dImpl : torch::nn::Module { ...@@ -18,7 +19,7 @@ struct BasicConv2dImpl : torch::nn::Module {
TORCH_MODULE(BasicConv2d); TORCH_MODULE(BasicConv2d);
struct InceptionImpl : torch::nn::Module { struct VISION_API InceptionImpl : torch::nn::Module {
BasicConv2d branch1{nullptr}; BasicConv2d branch1{nullptr};
torch::nn::Sequential branch2, branch3, branch4; torch::nn::Sequential branch2, branch3, branch4;
...@@ -36,7 +37,7 @@ struct InceptionImpl : torch::nn::Module { ...@@ -36,7 +37,7 @@ struct InceptionImpl : torch::nn::Module {
TORCH_MODULE(Inception); TORCH_MODULE(Inception);
struct InceptionAuxImpl : torch::nn::Module { struct VISION_API InceptionAuxImpl : torch::nn::Module {
BasicConv2d conv{nullptr}; BasicConv2d conv{nullptr};
torch::nn::Linear fc1{nullptr}, fc2{nullptr}; torch::nn::Linear fc1{nullptr}, fc2{nullptr};
...@@ -49,13 +50,13 @@ TORCH_MODULE(InceptionAux); ...@@ -49,13 +50,13 @@ TORCH_MODULE(InceptionAux);
} // namespace _googlenetimpl } // namespace _googlenetimpl
struct GoogLeNetOutput { struct VISION_API GoogLeNetOutput {
torch::Tensor output; torch::Tensor output;
torch::Tensor aux1; torch::Tensor aux1;
torch::Tensor aux2; torch::Tensor aux2;
}; };
struct GoogLeNetImpl : torch::nn::Module { struct VISION_API GoogLeNetImpl : torch::nn::Module {
bool aux_logits, transform_input; bool aux_logits, transform_input;
_googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#define INCEPTION_H #define INCEPTION_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
namespace _inceptionimpl { namespace _inceptionimpl {
struct BasicConv2dImpl : torch::nn::Module { struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr}; torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm bn{nullptr}; torch::nn::BatchNorm bn{nullptr};
...@@ -17,7 +18,7 @@ struct BasicConv2dImpl : torch::nn::Module { ...@@ -17,7 +18,7 @@ struct BasicConv2dImpl : torch::nn::Module {
TORCH_MODULE(BasicConv2d); TORCH_MODULE(BasicConv2d);
struct InceptionAImpl : torch::nn::Module { struct VISION_API InceptionAImpl : torch::nn::Module {
BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1, BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
branch3x3dbl_2, branch3x3dbl_3, branch_pool; branch3x3dbl_2, branch3x3dbl_3, branch_pool;
...@@ -26,7 +27,7 @@ struct InceptionAImpl : torch::nn::Module { ...@@ -26,7 +27,7 @@ struct InceptionAImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct InceptionBImpl : torch::nn::Module { struct VISION_API InceptionBImpl : torch::nn::Module {
BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3; BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;
InceptionBImpl(int64_t in_channels); InceptionBImpl(int64_t in_channels);
...@@ -34,7 +35,7 @@ struct InceptionBImpl : torch::nn::Module { ...@@ -34,7 +35,7 @@ struct InceptionBImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct InceptionCImpl : torch::nn::Module { struct VISION_API InceptionCImpl : torch::nn::Module {
BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr}, BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr}, branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr}, branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
...@@ -45,7 +46,7 @@ struct InceptionCImpl : torch::nn::Module { ...@@ -45,7 +46,7 @@ struct InceptionCImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct InceptionDImpl : torch::nn::Module { struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2, BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
branch7x7x3_3, branch7x7x3_4; branch7x7x3_3, branch7x7x3_4;
...@@ -54,7 +55,7 @@ struct InceptionDImpl : torch::nn::Module { ...@@ -54,7 +55,7 @@ struct InceptionDImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct InceptionEImpl : torch::nn::Module { struct VISION_API InceptionEImpl : torch::nn::Module {
BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b, BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
branch_pool; branch_pool;
...@@ -64,7 +65,7 @@ struct InceptionEImpl : torch::nn::Module { ...@@ -64,7 +65,7 @@ struct InceptionEImpl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct InceptionAuxImpl : torch::nn::Module { struct VISION_API InceptionAuxImpl : torch::nn::Module {
BasicConv2d conv0; BasicConv2d conv0;
BasicConv2d conv1; BasicConv2d conv1;
torch::nn::Linear fc; torch::nn::Linear fc;
...@@ -83,7 +84,7 @@ TORCH_MODULE(InceptionAux); ...@@ -83,7 +84,7 @@ TORCH_MODULE(InceptionAux);
} // namespace _inceptionimpl } // namespace _inceptionimpl
struct InceptionV3Output { struct VISION_API InceptionV3Output {
torch::Tensor output; torch::Tensor output;
torch::Tensor aux; torch::Tensor aux;
}; };
...@@ -91,7 +92,7 @@ struct InceptionV3Output { ...@@ -91,7 +92,7 @@ struct InceptionV3Output {
// Inception v3 model architecture from // Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision" //"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567> //<http://arxiv.org/abs/1512.00567>
struct InceptionV3Impl : torch::nn::Module { struct VISION_API InceptionV3Impl : torch::nn::Module {
bool aux_logits, transform_input; bool aux_logits, transform_input;
_inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr}, _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#define MOBILENET_H #define MOBILENET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
struct MobileNetV2Impl : torch::nn::Module { struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel; int64_t last_channel;
torch::nn::Sequential features, classifier; torch::nn::Sequential features, classifier;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define RESNET_H #define RESNET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
...@@ -19,7 +20,7 @@ torch::nn::Conv2d conv3x3( ...@@ -19,7 +20,7 @@ torch::nn::Conv2d conv3x3(
// 1x1 convolution // 1x1 convolution
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1); torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1);
struct BasicBlock : torch::nn::Module { struct VISION_API BasicBlock : torch::nn::Module {
template <typename Block> template <typename Block>
friend struct vision::models::ResNetImpl; friend struct vision::models::ResNetImpl;
...@@ -42,7 +43,7 @@ struct BasicBlock : torch::nn::Module { ...@@ -42,7 +43,7 @@ struct BasicBlock : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct Bottleneck : torch::nn::Module { struct VISION_API Bottleneck : torch::nn::Module {
template <typename Block> template <typename Block>
friend struct vision::models::ResNetImpl; friend struct vision::models::ResNetImpl;
...@@ -184,40 +185,40 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) { ...@@ -184,40 +185,40 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
return x; return x;
} }
struct ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> { struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false); ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
}; };
struct ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> { struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false); ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
}; };
struct ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false); ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
}; };
struct ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false); ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
}; };
struct ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false); ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
}; };
struct ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext50_32x4dImpl( ResNext50_32x4dImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
struct ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext101_32x8dImpl( ResNext101_32x8dImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
template <typename Block> template <typename Block>
struct ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> { struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder; using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
}; };
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#define SHUFFLENETV2_H #define SHUFFLENETV2_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
struct ShuffleNetV2Impl : torch::nn::Module { struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
std::vector<int64_t> _stage_out_channels; std::vector<int64_t> _stage_out_channels;
torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr}; torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr};
torch::nn::Linear fc{nullptr}; torch::nn::Linear fc{nullptr};
...@@ -19,19 +20,19 @@ struct ShuffleNetV2Impl : torch::nn::Module { ...@@ -19,19 +20,19 @@ struct ShuffleNetV2Impl : torch::nn::Module {
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000); ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
}; };
struct ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000); ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
}; };
struct ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000); ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
}; };
struct ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000); ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
}; };
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#define SQUEEZENET_H #define SQUEEZENET_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
struct SqueezeNetImpl : torch::nn::Module { struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t num_classes; int64_t num_classes;
torch::nn::Sequential features{nullptr}, classifier{nullptr}; torch::nn::Sequential features{nullptr}, classifier{nullptr};
...@@ -17,7 +18,7 @@ struct SqueezeNetImpl : torch::nn::Module { ...@@ -17,7 +18,7 @@ struct SqueezeNetImpl : torch::nn::Module {
// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level // SqueezeNet model architecture from the "SqueezeNet: AlexNet-level
// accuracy with 50x fewer parameters and <0.5MB model size" // accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper. // <https://arxiv.org/abs/1602.07360> paper.
struct SqueezeNet1_0Impl : SqueezeNetImpl { struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
SqueezeNet1_0Impl(int64_t num_classes = 1000); SqueezeNet1_0Impl(int64_t num_classes = 1000);
}; };
...@@ -25,7 +26,7 @@ struct SqueezeNet1_0Impl : SqueezeNetImpl { ...@@ -25,7 +26,7 @@ struct SqueezeNet1_0Impl : SqueezeNetImpl {
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>. // <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters // SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy. // than SqueezeNet 1.0, without sacrificing accuracy.
struct SqueezeNet1_1Impl : SqueezeNetImpl { struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
SqueezeNet1_1Impl(int64_t num_classes = 1000); SqueezeNet1_1Impl(int64_t num_classes = 1000);
}; };
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#define VGG_H #define VGG_H
#include <torch/torch.h> #include <torch/torch.h>
#include "general.h"
namespace vision { namespace vision {
namespace models { namespace models {
struct VGGImpl : torch::nn::Module { struct VISION_API VGGImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr}; torch::nn::Sequential features{nullptr}, classifier{nullptr};
void _initialize_weights(); void _initialize_weights();
...@@ -19,42 +20,42 @@ struct VGGImpl : torch::nn::Module { ...@@ -19,42 +20,42 @@ struct VGGImpl : torch::nn::Module {
}; };
// VGG 11-layer model (configuration "A") // VGG 11-layer model (configuration "A")
struct VGG11Impl : VGGImpl { struct VISION_API VGG11Impl : VGGImpl {
VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true); VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 13-layer model (configuration "B") // VGG 13-layer model (configuration "B")
struct VGG13Impl : VGGImpl { struct VISION_API VGG13Impl : VGGImpl {
VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true); VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 16-layer model (configuration "D") // VGG 16-layer model (configuration "D")
struct VGG16Impl : VGGImpl { struct VISION_API VGG16Impl : VGGImpl {
VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true); VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 19-layer model (configuration "E") // VGG 19-layer model (configuration "E")
struct VGG19Impl : VGGImpl { struct VISION_API VGG19Impl : VGGImpl {
VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true); VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 11-layer model (configuration "A") with batch normalization // VGG 11-layer model (configuration "A") with batch normalization
struct VGG11BNImpl : VGGImpl { struct VISION_API VGG11BNImpl : VGGImpl {
VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 13-layer model (configuration "B") with batch normalization // VGG 13-layer model (configuration "B") with batch normalization
struct VGG13BNImpl : VGGImpl { struct VISION_API VGG13BNImpl : VGGImpl {
VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 16-layer model (configuration "D") with batch normalization // VGG 16-layer model (configuration "D") with batch normalization
struct VGG16BNImpl : VGGImpl { struct VISION_API VGG16BNImpl : VGGImpl {
VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
// VGG 19-layer model (configuration 'E') with batch normalization // VGG 19-layer model (configuration 'E') with batch normalization
struct VGG19BNImpl : VGGImpl { struct VISION_API VGG19BNImpl : VGGImpl {
VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment