Unverified Commit 6b071be9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Define all C++ model constructors explicit (#2944)

* Making all model constructors explicit.

* formatting.
parent f95b0533
......@@ -11,7 +11,7 @@ namespace models {
struct VISION_API AlexNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr};
AlexNetImpl(int64_t num_classes = 1000);
explicit AlexNetImpl(int64_t num_classes = 1000);
torch::Tensor forward(torch::Tensor x);
};
......
......@@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr};
torch::nn::Linear classifier{nullptr};
DenseNetImpl(
explicit DenseNetImpl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
......@@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
};
struct VISION_API DenseNet121Impl : DenseNetImpl {
DenseNet121Impl(
explicit DenseNet121Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
......@@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
};
struct VISION_API DenseNet169Impl : DenseNetImpl {
DenseNet169Impl(
explicit DenseNet169Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 32, 32},
......@@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
};
struct VISION_API DenseNet201Impl : DenseNetImpl {
DenseNet201Impl(
explicit DenseNet201Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 48, 32},
......@@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
};
struct VISION_API DenseNet161Impl : DenseNetImpl {
DenseNet161Impl(
explicit DenseNet161Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 48,
const std::vector<int64_t>& block_config = {6, 12, 36, 24},
......
......@@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};
BasicConv2dImpl(torch::nn::Conv2dOptions options);
explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);
torch::Tensor forward(torch::Tensor x);
};
......@@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
torch::nn::Dropout dropout{nullptr};
torch::nn::Linear fc{nullptr};
GoogLeNetImpl(
explicit GoogLeNetImpl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false,
......
......@@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};
BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);
explicit BasicConv2dImpl(
torch::nn::Conv2dOptions options,
double std_dev = 0.1);
torch::Tensor forward(torch::Tensor x);
};
......@@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
struct VISION_API InceptionBImpl : torch::nn::Module {
BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;
InceptionBImpl(int64_t in_channels);
explicit InceptionBImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x);
};
......@@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
branch7x7x3_3, branch7x7x3_4;
InceptionDImpl(int64_t in_channels);
explicit InceptionDImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x);
};
......@@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
branch_pool;
InceptionEImpl(int64_t in_channels);
explicit InceptionEImpl(int64_t in_channels);
torch::Tensor forward(const torch::Tensor& x);
};
......@@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {
_inceptionimpl::InceptionAux AuxLogits{nullptr};
InceptionV3Impl(
explicit InceptionV3Impl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false);
......
......@@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {
void _initialize_weights();
MNASNetImpl(double alpha, int64_t num_classes = 1000, double dropout = .2);
explicit MNASNetImpl(
double alpha,
int64_t num_classes = 1000,
double dropout = .2);
torch::Tensor forward(torch::Tensor x);
};
struct MNASNet0_5Impl : MNASNetImpl {
MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
};
struct MNASNet0_75Impl : MNASNetImpl {
MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
};
struct MNASNet1_0Impl : MNASNetImpl {
MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
};
struct MNASNet1_3Impl : MNASNetImpl {
MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
};
TORCH_MODULE(MNASNet);
......
......@@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel;
torch::nn::Sequential features, classifier;
MobileNetV2Impl(
explicit MobileNetV2Impl(
int64_t num_classes = 1000,
double width_mult = 1.0,
std::vector<std::vector<int64_t>> inverted_residual_settings = {},
......
......@@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
int64_t blocks,
int64_t stride = 1);
ResNetImpl(
explicit ResNetImpl(
const std::vector<int>& layers,
int64_t num_classes = 1000,
bool zero_init_residual = false,
......@@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
}
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet18Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet34Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet50Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet101Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet152Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext50_32x4dImpl(
explicit ResNext50_32x4dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext101_32x8dImpl(
explicit ResNext101_32x8dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet50_2Impl(
explicit WideResNet50_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet101_2Impl(
explicit WideResNet101_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
......
......@@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
};
struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
};
struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
};
struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
};
struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
};
TORCH_MODULE(ShuffleNetV2);
......
......@@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t num_classes;
torch::nn::Sequential features{nullptr}, classifier{nullptr};
SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
torch::Tensor forward(torch::Tensor x);
};
......@@ -19,7 +19,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
SqueezeNet1_0Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
};
// SqueezeNet 1.1 model from the official SqueezeNet repo
......@@ -27,7 +27,7 @@ struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
SqueezeNet1_1Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
};
TORCH_MODULE(SqueezeNet);
......
......@@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {
void _initialize_weights();
VGGImpl(
explicit VGGImpl(
const torch::nn::Sequential& features,
int64_t num_classes = 1000,
bool initialize_weights = true);
......@@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {
// VGG 11-layer model (configuration "A")
struct VISION_API VGG11Impl : VGGImpl {
VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 13-layer model (configuration "B")
struct VISION_API VGG13Impl : VGGImpl {
VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 16-layer model (configuration "D")
struct VISION_API VGG16Impl : VGGImpl {
VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 19-layer model (configuration "E")
struct VISION_API VGG19Impl : VGGImpl {
VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 11-layer model (configuration "A") with batch normalization
struct VISION_API VGG11BNImpl : VGGImpl {
VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 13-layer model (configuration "B") with batch normalization
struct VISION_API VGG13BNImpl : VGGImpl {
VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 16-layer model (configuration "D") with batch normalization
struct VISION_API VGG16BNImpl : VGGImpl {
VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
// VGG 19-layer model (configuration 'E') with batch normalization
struct VISION_API VGG19BNImpl : VGGImpl {
VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};
TORCH_MODULE(VGG);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment