Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6b071be9
Unverified
Commit
6b071be9
authored
Nov 03, 2020
by
Vasilis Vryniotis
Committed by
GitHub
Nov 03, 2020
Browse files
Define all C++ model constructors explicit (#2944)
* Making all model constructors explicit. * formatting.
parent
f95b0533
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
76 additions
and
45 deletions
+76
-45
torchvision/csrc/models/alexnet.h
torchvision/csrc/models/alexnet.h
+1
-1
torchvision/csrc/models/densenet.h
torchvision/csrc/models/densenet.h
+5
-5
torchvision/csrc/models/googlenet.h
torchvision/csrc/models/googlenet.h
+2
-2
torchvision/csrc/models/inception.h
torchvision/csrc/models/inception.h
+7
-5
torchvision/csrc/models/mnasnet.h
torchvision/csrc/models/mnasnet.h
+8
-5
torchvision/csrc/models/mobilenet.h
torchvision/csrc/models/mobilenet.h
+1
-1
torchvision/csrc/models/resnet.h
torchvision/csrc/models/resnet.h
+20
-10
torchvision/csrc/models/shufflenetv2.h
torchvision/csrc/models/shufflenetv2.h
+4
-4
torchvision/csrc/models/squeezenet.h
torchvision/csrc/models/squeezenet.h
+3
-3
torchvision/csrc/models/vgg.h
torchvision/csrc/models/vgg.h
+25
-9
No files found.
torchvision/csrc/models/alexnet.h
View file @
6b071be9
...
@@ -11,7 +11,7 @@ namespace models {
...
@@ -11,7 +11,7 @@ namespace models {
struct
VISION_API
AlexNetImpl
:
torch
::
nn
::
Module
{
struct
VISION_API
AlexNetImpl
:
torch
::
nn
::
Module
{
torch
::
nn
::
Sequential
features
{
nullptr
},
classifier
{
nullptr
};
torch
::
nn
::
Sequential
features
{
nullptr
},
classifier
{
nullptr
};
AlexNetImpl
(
int64_t
num_classes
=
1000
);
explicit
AlexNetImpl
(
int64_t
num_classes
=
1000
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
};
};
...
...
torchvision/csrc/models/densenet.h
View file @
6b071be9
...
@@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
...
@@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
torch
::
nn
::
Sequential
features
{
nullptr
};
torch
::
nn
::
Sequential
features
{
nullptr
};
torch
::
nn
::
Linear
classifier
{
nullptr
};
torch
::
nn
::
Linear
classifier
{
nullptr
};
DenseNetImpl
(
explicit
DenseNetImpl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
int64_t
growth_rate
=
32
,
int64_t
growth_rate
=
32
,
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
24
,
16
},
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
24
,
16
},
...
@@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
...
@@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
};
};
struct
VISION_API
DenseNet121Impl
:
DenseNetImpl
{
struct
VISION_API
DenseNet121Impl
:
DenseNetImpl
{
DenseNet121Impl
(
explicit
DenseNet121Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
int64_t
growth_rate
=
32
,
int64_t
growth_rate
=
32
,
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
24
,
16
},
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
24
,
16
},
...
@@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
...
@@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
};
};
struct
VISION_API
DenseNet169Impl
:
DenseNetImpl
{
struct
VISION_API
DenseNet169Impl
:
DenseNetImpl
{
DenseNet169Impl
(
explicit
DenseNet169Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
int64_t
growth_rate
=
32
,
int64_t
growth_rate
=
32
,
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
32
,
32
},
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
32
,
32
},
...
@@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
...
@@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
};
};
struct
VISION_API
DenseNet201Impl
:
DenseNetImpl
{
struct
VISION_API
DenseNet201Impl
:
DenseNetImpl
{
DenseNet201Impl
(
explicit
DenseNet201Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
int64_t
growth_rate
=
32
,
int64_t
growth_rate
=
32
,
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
48
,
32
},
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
48
,
32
},
...
@@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
...
@@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
};
};
struct
VISION_API
DenseNet161Impl
:
DenseNetImpl
{
struct
VISION_API
DenseNet161Impl
:
DenseNetImpl
{
DenseNet161Impl
(
explicit
DenseNet161Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
int64_t
growth_rate
=
48
,
int64_t
growth_rate
=
48
,
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
36
,
24
},
const
std
::
vector
<
int64_t
>&
block_config
=
{
6
,
12
,
36
,
24
},
...
...
torchvision/csrc/models/googlenet.h
View file @
6b071be9
...
@@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
...
@@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
BatchNorm2d
bn
{
nullptr
};
torch
::
nn
::
BatchNorm2d
bn
{
nullptr
};
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
);
explicit
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
};
};
...
@@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
...
@@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
torch
::
nn
::
Dropout
dropout
{
nullptr
};
torch
::
nn
::
Dropout
dropout
{
nullptr
};
torch
::
nn
::
Linear
fc
{
nullptr
};
torch
::
nn
::
Linear
fc
{
nullptr
};
GoogLeNetImpl
(
explicit
GoogLeNetImpl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
aux_logits
=
true
,
bool
aux_logits
=
true
,
bool
transform_input
=
false
,
bool
transform_input
=
false
,
...
...
torchvision/csrc/models/inception.h
View file @
6b071be9
...
@@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
...
@@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
BatchNorm2d
bn
{
nullptr
};
torch
::
nn
::
BatchNorm2d
bn
{
nullptr
};
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
,
double
std_dev
=
0.1
);
explicit
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
,
double
std_dev
=
0.1
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
};
};
...
@@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
...
@@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
struct
VISION_API
InceptionBImpl
:
torch
::
nn
::
Module
{
struct
VISION_API
InceptionBImpl
:
torch
::
nn
::
Module
{
BasicConv2d
branch3x3
,
branch3x3dbl_1
,
branch3x3dbl_2
,
branch3x3dbl_3
;
BasicConv2d
branch3x3
,
branch3x3dbl_1
,
branch3x3dbl_2
,
branch3x3dbl_3
;
InceptionBImpl
(
int64_t
in_channels
);
explicit
InceptionBImpl
(
int64_t
in_channels
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
};
};
...
@@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
...
@@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d
branch3x3_1
,
branch3x3_2
,
branch7x7x3_1
,
branch7x7x3_2
,
BasicConv2d
branch3x3_1
,
branch3x3_2
,
branch7x7x3_1
,
branch7x7x3_2
,
branch7x7x3_3
,
branch7x7x3_4
;
branch7x7x3_3
,
branch7x7x3_4
;
InceptionDImpl
(
int64_t
in_channels
);
explicit
InceptionDImpl
(
int64_t
in_channels
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
};
};
...
@@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
...
@@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
branch3x3dbl_1
,
branch3x3dbl_2
,
branch3x3dbl_3a
,
branch3x3dbl_3b
,
branch3x3dbl_1
,
branch3x3dbl_2
,
branch3x3dbl_3a
,
branch3x3dbl_3b
,
branch_pool
;
branch_pool
;
InceptionEImpl
(
int64_t
in_channels
);
explicit
InceptionEImpl
(
int64_t
in_channels
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
torch
::
Tensor
forward
(
const
torch
::
Tensor
&
x
);
};
};
...
@@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {
...
@@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {
_inceptionimpl
::
InceptionAux
AuxLogits
{
nullptr
};
_inceptionimpl
::
InceptionAux
AuxLogits
{
nullptr
};
InceptionV3Impl
(
explicit
InceptionV3Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
aux_logits
=
true
,
bool
aux_logits
=
true
,
bool
transform_input
=
false
);
bool
transform_input
=
false
);
...
...
torchvision/csrc/models/mnasnet.h
View file @
6b071be9
...
@@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {
...
@@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {
void
_initialize_weights
();
void
_initialize_weights
();
MNASNetImpl
(
double
alpha
,
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
explicit
MNASNetImpl
(
double
alpha
,
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
};
};
struct
MNASNet0_5Impl
:
MNASNetImpl
{
struct
MNASNet0_5Impl
:
MNASNetImpl
{
MNASNet0_5Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
explicit
MNASNet0_5Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
};
};
struct
MNASNet0_75Impl
:
MNASNetImpl
{
struct
MNASNet0_75Impl
:
MNASNetImpl
{
MNASNet0_75Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
explicit
MNASNet0_75Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
};
};
struct
MNASNet1_0Impl
:
MNASNetImpl
{
struct
MNASNet1_0Impl
:
MNASNetImpl
{
MNASNet1_0Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
explicit
MNASNet1_0Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
};
};
struct
MNASNet1_3Impl
:
MNASNetImpl
{
struct
MNASNet1_3Impl
:
MNASNetImpl
{
MNASNet1_3Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
explicit
MNASNet1_3Impl
(
int64_t
num_classes
=
1000
,
double
dropout
=
.2
);
};
};
TORCH_MODULE
(
MNASNet
);
TORCH_MODULE
(
MNASNet
);
...
...
torchvision/csrc/models/mobilenet.h
View file @
6b071be9
...
@@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
...
@@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t
last_channel
;
int64_t
last_channel
;
torch
::
nn
::
Sequential
features
,
classifier
;
torch
::
nn
::
Sequential
features
,
classifier
;
MobileNetV2Impl
(
explicit
MobileNetV2Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
double
width_mult
=
1.0
,
double
width_mult
=
1.0
,
std
::
vector
<
std
::
vector
<
int64_t
>>
inverted_residual_settings
=
{},
std
::
vector
<
std
::
vector
<
int64_t
>>
inverted_residual_settings
=
{},
...
...
torchvision/csrc/models/resnet.h
View file @
6b071be9
...
@@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
...
@@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
int64_t
blocks
,
int64_t
blocks
,
int64_t
stride
=
1
);
int64_t
stride
=
1
);
ResNetImpl
(
explicit
ResNetImpl
(
const
std
::
vector
<
int
>&
layers
,
const
std
::
vector
<
int
>&
layers
,
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
,
bool
zero_init_residual
=
false
,
...
@@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
...
@@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
}
}
struct
VISION_API
ResNet18Impl
:
ResNetImpl
<
_resnetimpl
::
BasicBlock
>
{
struct
VISION_API
ResNet18Impl
:
ResNetImpl
<
_resnetimpl
::
BasicBlock
>
{
ResNet18Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
explicit
ResNet18Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNet34Impl
:
ResNetImpl
<
_resnetimpl
::
BasicBlock
>
{
struct
VISION_API
ResNet34Impl
:
ResNetImpl
<
_resnetimpl
::
BasicBlock
>
{
ResNet34Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
explicit
ResNet34Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNet50Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
ResNet50Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
ResNet50Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
explicit
ResNet50Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNet101Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
ResNet101Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
ResNet101Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
explicit
ResNet101Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNet152Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
ResNet152Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
ResNet152Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
explicit
ResNet152Impl
(
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNext50_32x4dImpl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
ResNext50_32x4dImpl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
ResNext50_32x4dImpl
(
explicit
ResNext50_32x4dImpl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
ResNext101_32x8dImpl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
ResNext101_32x8dImpl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
ResNext101_32x8dImpl
(
explicit
ResNext101_32x8dImpl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
WideResNet50_2Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
WideResNet50_2Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
WideResNet50_2Impl
(
explicit
WideResNet50_2Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
bool
zero_init_residual
=
false
);
};
};
struct
VISION_API
WideResNet101_2Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
struct
VISION_API
WideResNet101_2Impl
:
ResNetImpl
<
_resnetimpl
::
Bottleneck
>
{
WideResNet101_2Impl
(
explicit
WideResNet101_2Impl
(
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
zero_init_residual
=
false
);
bool
zero_init_residual
=
false
);
};
};
...
...
torchvision/csrc/models/shufflenetv2.h
View file @
6b071be9
...
@@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
...
@@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
};
};
struct
VISION_API
ShuffleNetV2_x0_5Impl
:
ShuffleNetV2Impl
{
struct
VISION_API
ShuffleNetV2_x0_5Impl
:
ShuffleNetV2Impl
{
ShuffleNetV2_x0_5Impl
(
int64_t
num_classes
=
1000
);
explicit
ShuffleNetV2_x0_5Impl
(
int64_t
num_classes
=
1000
);
};
};
struct
VISION_API
ShuffleNetV2_x1_0Impl
:
ShuffleNetV2Impl
{
struct
VISION_API
ShuffleNetV2_x1_0Impl
:
ShuffleNetV2Impl
{
ShuffleNetV2_x1_0Impl
(
int64_t
num_classes
=
1000
);
explicit
ShuffleNetV2_x1_0Impl
(
int64_t
num_classes
=
1000
);
};
};
struct
VISION_API
ShuffleNetV2_x1_5Impl
:
ShuffleNetV2Impl
{
struct
VISION_API
ShuffleNetV2_x1_5Impl
:
ShuffleNetV2Impl
{
ShuffleNetV2_x1_5Impl
(
int64_t
num_classes
=
1000
);
explicit
ShuffleNetV2_x1_5Impl
(
int64_t
num_classes
=
1000
);
};
};
struct
VISION_API
ShuffleNetV2_x2_0Impl
:
ShuffleNetV2Impl
{
struct
VISION_API
ShuffleNetV2_x2_0Impl
:
ShuffleNetV2Impl
{
ShuffleNetV2_x2_0Impl
(
int64_t
num_classes
=
1000
);
explicit
ShuffleNetV2_x2_0Impl
(
int64_t
num_classes
=
1000
);
};
};
TORCH_MODULE
(
ShuffleNetV2
);
TORCH_MODULE
(
ShuffleNetV2
);
...
...
torchvision/csrc/models/squeezenet.h
View file @
6b071be9
...
@@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
...
@@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t
num_classes
;
int64_t
num_classes
;
torch
::
nn
::
Sequential
features
{
nullptr
},
classifier
{
nullptr
};
torch
::
nn
::
Sequential
features
{
nullptr
},
classifier
{
nullptr
};
SqueezeNetImpl
(
double
version
=
1.0
,
int64_t
num_classes
=
1000
);
explicit
SqueezeNetImpl
(
double
version
=
1.0
,
int64_t
num_classes
=
1000
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
torch
::
Tensor
forward
(
torch
::
Tensor
x
);
};
};
...
@@ -19,7 +19,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
...
@@ -19,7 +19,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
// accuracy with 50x fewer parameters and <0.5MB model size"
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
// <https://arxiv.org/abs/1602.07360> paper.
struct
VISION_API
SqueezeNet1_0Impl
:
SqueezeNetImpl
{
struct
VISION_API
SqueezeNet1_0Impl
:
SqueezeNetImpl
{
SqueezeNet1_0Impl
(
int64_t
num_classes
=
1000
);
explicit
SqueezeNet1_0Impl
(
int64_t
num_classes
=
1000
);
};
};
// SqueezeNet 1.1 model from the official SqueezeNet repo
// SqueezeNet 1.1 model from the official SqueezeNet repo
...
@@ -27,7 +27,7 @@ struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
...
@@ -27,7 +27,7 @@ struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
// than SqueezeNet 1.0, without sacrificing accuracy.
struct
VISION_API
SqueezeNet1_1Impl
:
SqueezeNetImpl
{
struct
VISION_API
SqueezeNet1_1Impl
:
SqueezeNetImpl
{
SqueezeNet1_1Impl
(
int64_t
num_classes
=
1000
);
explicit
SqueezeNet1_1Impl
(
int64_t
num_classes
=
1000
);
};
};
TORCH_MODULE
(
SqueezeNet
);
TORCH_MODULE
(
SqueezeNet
);
...
...
torchvision/csrc/models/vgg.h
View file @
6b071be9
...
@@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {
...
@@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {
void
_initialize_weights
();
void
_initialize_weights
();
VGGImpl
(
explicit
VGGImpl
(
const
torch
::
nn
::
Sequential
&
features
,
const
torch
::
nn
::
Sequential
&
features
,
int64_t
num_classes
=
1000
,
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
bool
initialize_weights
=
true
);
...
@@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {
...
@@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {
// VGG 11-layer model (configuration "A")
// VGG 11-layer model (configuration "A")
struct
VISION_API
VGG11Impl
:
VGGImpl
{
struct
VISION_API
VGG11Impl
:
VGGImpl
{
VGG11Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG11Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 13-layer model (configuration "B")
// VGG 13-layer model (configuration "B")
struct
VISION_API
VGG13Impl
:
VGGImpl
{
struct
VISION_API
VGG13Impl
:
VGGImpl
{
VGG13Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG13Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 16-layer model (configuration "D")
// VGG 16-layer model (configuration "D")
struct
VISION_API
VGG16Impl
:
VGGImpl
{
struct
VISION_API
VGG16Impl
:
VGGImpl
{
VGG16Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG16Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 19-layer model (configuration "E")
// VGG 19-layer model (configuration "E")
struct
VISION_API
VGG19Impl
:
VGGImpl
{
struct
VISION_API
VGG19Impl
:
VGGImpl
{
VGG19Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG19Impl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 11-layer model (configuration "A") with batch normalization
// VGG 11-layer model (configuration "A") with batch normalization
struct
VISION_API
VGG11BNImpl
:
VGGImpl
{
struct
VISION_API
VGG11BNImpl
:
VGGImpl
{
VGG11BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG11BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 13-layer model (configuration "B") with batch normalization
// VGG 13-layer model (configuration "B") with batch normalization
struct
VISION_API
VGG13BNImpl
:
VGGImpl
{
struct
VISION_API
VGG13BNImpl
:
VGGImpl
{
VGG13BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG13BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 16-layer model (configuration "D") with batch normalization
// VGG 16-layer model (configuration "D") with batch normalization
struct
VISION_API
VGG16BNImpl
:
VGGImpl
{
struct
VISION_API
VGG16BNImpl
:
VGGImpl
{
VGG16BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG16BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
// VGG 19-layer model (configuration 'E') with batch normalization
// VGG 19-layer model (configuration 'E') with batch normalization
struct
VISION_API
VGG19BNImpl
:
VGGImpl
{
struct
VISION_API
VGG19BNImpl
:
VGGImpl
{
VGG19BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
explicit
VGG19BNImpl
(
int64_t
num_classes
=
1000
,
bool
initialize_weights
=
true
);
};
};
TORCH_MODULE
(
VGG
);
TORCH_MODULE
(
VGG
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment