Unverified Commit 6b071be9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Define all C++ model constructors explicit (#2944)

* Making all model constructors explicit.

* formatting.
parent f95b0533
...@@ -11,7 +11,7 @@ namespace models { ...@@ -11,7 +11,7 @@ namespace models {
struct VISION_API AlexNetImpl : torch::nn::Module { struct VISION_API AlexNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr}; torch::nn::Sequential features{nullptr}, classifier{nullptr};
AlexNetImpl(int64_t num_classes = 1000); explicit AlexNetImpl(int64_t num_classes = 1000);
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
......
...@@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module { ...@@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}; torch::nn::Sequential features{nullptr};
torch::nn::Linear classifier{nullptr}; torch::nn::Linear classifier{nullptr};
DenseNetImpl( explicit DenseNetImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16}, const std::vector<int64_t>& block_config = {6, 12, 24, 16},
...@@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module { ...@@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
}; };
struct VISION_API DenseNet121Impl : DenseNetImpl { struct VISION_API DenseNet121Impl : DenseNetImpl {
DenseNet121Impl( explicit DenseNet121Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16}, const std::vector<int64_t>& block_config = {6, 12, 24, 16},
...@@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl { ...@@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
}; };
struct VISION_API DenseNet169Impl : DenseNetImpl { struct VISION_API DenseNet169Impl : DenseNetImpl {
DenseNet169Impl( explicit DenseNet169Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 32, 32}, const std::vector<int64_t>& block_config = {6, 12, 32, 32},
...@@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl { ...@@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
}; };
struct VISION_API DenseNet201Impl : DenseNetImpl { struct VISION_API DenseNet201Impl : DenseNetImpl {
DenseNet201Impl( explicit DenseNet201Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 32, int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 48, 32}, const std::vector<int64_t>& block_config = {6, 12, 48, 32},
...@@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl { ...@@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
}; };
struct VISION_API DenseNet161Impl : DenseNetImpl { struct VISION_API DenseNet161Impl : DenseNetImpl {
DenseNet161Impl( explicit DenseNet161Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
int64_t growth_rate = 48, int64_t growth_rate = 48,
const std::vector<int64_t>& block_config = {6, 12, 36, 24}, const std::vector<int64_t>& block_config = {6, 12, 36, 24},
......
...@@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module { ...@@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr}; torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr}; torch::nn::BatchNorm2d bn{nullptr};
BasicConv2dImpl(torch::nn::Conv2dOptions options); explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
...@@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module { ...@@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
torch::nn::Dropout dropout{nullptr}; torch::nn::Dropout dropout{nullptr};
torch::nn::Linear fc{nullptr}; torch::nn::Linear fc{nullptr};
GoogLeNetImpl( explicit GoogLeNetImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool aux_logits = true, bool aux_logits = true,
bool transform_input = false, bool transform_input = false,
......
...@@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module { ...@@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr}; torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr}; torch::nn::BatchNorm2d bn{nullptr};
BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1); explicit BasicConv2dImpl(
torch::nn::Conv2dOptions options,
double std_dev = 0.1);
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
...@@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module { ...@@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
struct VISION_API InceptionBImpl : torch::nn::Module { struct VISION_API InceptionBImpl : torch::nn::Module {
BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3; BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;
InceptionBImpl(int64_t in_channels); explicit InceptionBImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x); torch::Tensor forward(const torch::Tensor& x);
}; };
...@@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module { ...@@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2, BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
branch7x7x3_3, branch7x7x3_4; branch7x7x3_3, branch7x7x3_4;
InceptionDImpl(int64_t in_channels); explicit InceptionDImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x); torch::Tensor forward(const torch::Tensor& x);
}; };
...@@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module { ...@@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
branch_pool; branch_pool;
InceptionEImpl(int64_t in_channels); explicit InceptionEImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x); torch::Tensor forward(const torch::Tensor& x);
}; };
...@@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module { ...@@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {
_inceptionimpl::InceptionAux AuxLogits{nullptr}; _inceptionimpl::InceptionAux AuxLogits{nullptr};
InceptionV3Impl( explicit InceptionV3Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool aux_logits = true, bool aux_logits = true,
bool transform_input = false); bool transform_input = false);
......
...@@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module { ...@@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {
void _initialize_weights(); void _initialize_weights();
MNASNetImpl(double alpha, int64_t num_classes = 1000, double dropout = .2); explicit MNASNetImpl(
double alpha,
int64_t num_classes = 1000,
double dropout = .2);
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
struct MNASNet0_5Impl : MNASNetImpl { struct MNASNet0_5Impl : MNASNetImpl {
MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2); explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
}; };
struct MNASNet0_75Impl : MNASNetImpl { struct MNASNet0_75Impl : MNASNetImpl {
MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2); explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
}; };
struct MNASNet1_0Impl : MNASNetImpl { struct MNASNet1_0Impl : MNASNetImpl {
MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2); explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
}; };
struct MNASNet1_3Impl : MNASNetImpl { struct MNASNet1_3Impl : MNASNetImpl {
MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2); explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
}; };
TORCH_MODULE(MNASNet); TORCH_MODULE(MNASNet);
......
...@@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module { ...@@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel; int64_t last_channel;
torch::nn::Sequential features, classifier; torch::nn::Sequential features, classifier;
MobileNetV2Impl( explicit MobileNetV2Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
double width_mult = 1.0, double width_mult = 1.0,
std::vector<std::vector<int64_t>> inverted_residual_settings = {}, std::vector<std::vector<int64_t>> inverted_residual_settings = {},
......
...@@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module { ...@@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
int64_t blocks, int64_t blocks,
int64_t stride = 1); int64_t stride = 1);
ResNetImpl( explicit ResNetImpl(
const std::vector<int>& layers, const std::vector<int>& layers,
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false, bool zero_init_residual = false,
...@@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) { ...@@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
} }
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> { struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false); explicit ResNet18Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
}; };
struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> { struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false); explicit ResNet34Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
}; };
struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false); explicit ResNet50Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
}; };
struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false); explicit ResNet101Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
}; };
struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false); explicit ResNet152Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
}; };
struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext50_32x4dImpl( explicit ResNext50_32x4dImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext101_32x8dImpl( explicit ResNext101_32x8dImpl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet50_2Impl( explicit WideResNet50_2Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> { struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet101_2Impl( explicit WideResNet101_2Impl(
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool zero_init_residual = false); bool zero_init_residual = false);
}; };
......
...@@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module { ...@@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
}; };
struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000); explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
}; };
struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000); explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
}; };
struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000); explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
}; };
struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl { struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000); explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
}; };
TORCH_MODULE(ShuffleNetV2); TORCH_MODULE(ShuffleNetV2);
......
...@@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module { ...@@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t num_classes; int64_t num_classes;
torch::nn::Sequential features{nullptr}, classifier{nullptr}; torch::nn::Sequential features{nullptr}, classifier{nullptr};
SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000); explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
torch::Tensor forward(torch::Tensor x); torch::Tensor forward(torch::Tensor x);
}; };
...@@ -19,7 +19,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module { ...@@ -19,7 +19,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
// accuracy with 50x fewer parameters and <0.5MB model size" // accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper. // <https://arxiv.org/abs/1602.07360> paper.
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl { struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
SqueezeNet1_0Impl(int64_t num_classes = 1000); explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
}; };
// SqueezeNet 1.1 model from the official SqueezeNet repo // SqueezeNet 1.1 model from the official SqueezeNet repo
...@@ -27,7 +27,7 @@ struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl { ...@@ -27,7 +27,7 @@ struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters // SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy. // than SqueezeNet 1.0, without sacrificing accuracy.
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl { struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
SqueezeNet1_1Impl(int64_t num_classes = 1000); explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
}; };
TORCH_MODULE(SqueezeNet); TORCH_MODULE(SqueezeNet);
......
...@@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module { ...@@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {
void _initialize_weights(); void _initialize_weights();
VGGImpl( explicit VGGImpl(
const torch::nn::Sequential& features, const torch::nn::Sequential& features,
int64_t num_classes = 1000, int64_t num_classes = 1000,
bool initialize_weights = true); bool initialize_weights = true);
...@@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module { ...@@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {
// VGG 11-layer model (configuration "A") // VGG 11-layer model (configuration "A")
struct VISION_API VGG11Impl : VGGImpl { struct VISION_API VGG11Impl : VGGImpl {
VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG11Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 13-layer model (configuration "B") // VGG 13-layer model (configuration "B")
struct VISION_API VGG13Impl : VGGImpl { struct VISION_API VGG13Impl : VGGImpl {
VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG13Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 16-layer model (configuration "D") // VGG 16-layer model (configuration "D")
struct VISION_API VGG16Impl : VGGImpl { struct VISION_API VGG16Impl : VGGImpl {
VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG16Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 19-layer model (configuration "E") // VGG 19-layer model (configuration "E")
struct VISION_API VGG19Impl : VGGImpl { struct VISION_API VGG19Impl : VGGImpl {
VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG19Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 11-layer model (configuration "A") with batch normalization // VGG 11-layer model (configuration "A") with batch normalization
struct VISION_API VGG11BNImpl : VGGImpl { struct VISION_API VGG11BNImpl : VGGImpl {
VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG11BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 13-layer model (configuration "B") with batch normalization // VGG 13-layer model (configuration "B") with batch normalization
struct VISION_API VGG13BNImpl : VGGImpl { struct VISION_API VGG13BNImpl : VGGImpl {
VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG13BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 16-layer model (configuration "D") with batch normalization // VGG 16-layer model (configuration "D") with batch normalization
struct VISION_API VGG16BNImpl : VGGImpl { struct VISION_API VGG16BNImpl : VGGImpl {
VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG16BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
// VGG 19-layer model (configuration 'E') with batch normalization // VGG 19-layer model (configuration 'E') with batch normalization
struct VISION_API VGG19BNImpl : VGGImpl { struct VISION_API VGG19BNImpl : VGGImpl {
VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); explicit VGG19BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
}; };
TORCH_MODULE(VGG); TORCH_MODULE(VGG);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment