Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b6f28ec1
Unverified
Commit
b6f28ec1
authored
Mar 04, 2020
by
Francis Charette Migneault
Committed by
GitHub
Mar 04, 2020
Browse files
replace torch 1.5.0 items flagged with deprecation warnings (fix #1906) (#1918)
parent
6aa99ced
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
48 additions
and
48 deletions
+48
-48
torchvision/csrc/models/densenet.cpp
torchvision/csrc/models/densenet.cpp
+6
-6
torchvision/csrc/models/googlenet.cpp
torchvision/csrc/models/googlenet.cpp
+2
-2
torchvision/csrc/models/googlenet.h
torchvision/csrc/models/googlenet.h
+1
-1
torchvision/csrc/models/inception.cpp
torchvision/csrc/models/inception.cpp
+1
-1
torchvision/csrc/models/inception.h
torchvision/csrc/models/inception.h
+1
-1
torchvision/csrc/models/mnasnet.cpp
torchvision/csrc/models/mnasnet.cpp
+10
-10
torchvision/csrc/models/mobilenet.cpp
torchvision/csrc/models/mobilenet.cpp
+4
-4
torchvision/csrc/models/resnet.cpp
torchvision/csrc/models/resnet.cpp
+5
-5
torchvision/csrc/models/resnet.h
torchvision/csrc/models/resnet.h
+7
-7
torchvision/csrc/models/shufflenetv2.cpp
torchvision/csrc/models/shufflenetv2.cpp
+7
-7
torchvision/csrc/models/vgg.cpp
torchvision/csrc/models/vgg.cpp
+4
-4
No files found.
torchvision/csrc/models/densenet.cpp
View file @
b6f28ec1
...
@@ -15,14 +15,14 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
...
@@ -15,14 +15,14 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
int64_t
bn_size
,
int64_t
bn_size
,
double
drop_rate
)
double
drop_rate
)
:
drop_rate
(
drop_rate
)
{
:
drop_rate
(
drop_rate
)
{
push_back
(
"norm1"
,
torch
::
nn
::
BatchNorm
(
num_input_features
));
push_back
(
"norm1"
,
torch
::
nn
::
BatchNorm
2d
(
num_input_features
));
push_back
(
"relu1"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
"relu1"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
push_back
(
"conv1"
,
"conv1"
,
torch
::
nn
::
Conv2d
(
Options
(
num_input_features
,
bn_size
*
growth_rate
,
1
)
torch
::
nn
::
Conv2d
(
Options
(
num_input_features
,
bn_size
*
growth_rate
,
1
)
.
stride
(
1
)
.
stride
(
1
)
.
bias
(
false
)));
.
bias
(
false
)));
push_back
(
"norm2"
,
torch
::
nn
::
BatchNorm
(
bn_size
*
growth_rate
));
push_back
(
"norm2"
,
torch
::
nn
::
BatchNorm
2d
(
bn_size
*
growth_rate
));
push_back
(
"relu2"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
"relu2"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
push_back
(
"conv2"
,
"conv2"
,
...
@@ -69,7 +69,7 @@ TORCH_MODULE(_DenseBlock);
...
@@ -69,7 +69,7 @@ TORCH_MODULE(_DenseBlock);
struct
_TransitionImpl
:
torch
::
nn
::
SequentialImpl
{
struct
_TransitionImpl
:
torch
::
nn
::
SequentialImpl
{
_TransitionImpl
(
int64_t
num_input_features
,
int64_t
num_output_features
)
{
_TransitionImpl
(
int64_t
num_input_features
,
int64_t
num_output_features
)
{
push_back
(
"norm"
,
torch
::
nn
::
BatchNorm
(
num_input_features
));
push_back
(
"norm"
,
torch
::
nn
::
BatchNorm
2d
(
num_input_features
));
push_back
(
"relu "
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
"relu "
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
push_back
(
push_back
(
"conv"
,
"conv"
,
...
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
...
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
torch
::
nn
::
Conv2d
(
torch
::
nn
::
Conv2d
(
Options
(
3
,
num_init_features
,
7
).
stride
(
2
).
padding
(
3
).
bias
(
false
)));
Options
(
3
,
num_init_features
,
7
).
stride
(
2
).
padding
(
3
).
bias
(
false
)));
features
->
push_back
(
"norm0"
,
torch
::
nn
::
BatchNorm
(
num_init_features
));
features
->
push_back
(
"norm0"
,
torch
::
nn
::
BatchNorm
2d
(
num_init_features
));
features
->
push_back
(
"relu0"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
features
->
push_back
(
"relu0"
,
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
features
->
push_back
(
features
->
push_back
(
"pool0"
,
torch
::
nn
::
Functional
(
torch
::
max_pool2d
,
3
,
2
,
1
,
1
,
false
));
"pool0"
,
torch
::
nn
::
Functional
(
torch
::
max_pool2d
,
3
,
2
,
1
,
1
,
false
));
...
@@ -125,7 +125,7 @@ DenseNetImpl::DenseNetImpl(
...
@@ -125,7 +125,7 @@ DenseNetImpl::DenseNetImpl(
}
}
// Final batch norm
// Final batch norm
features
->
push_back
(
"norm5"
,
torch
::
nn
::
BatchNorm
(
num_features
));
features
->
push_back
(
"norm5"
,
torch
::
nn
::
BatchNorm
2d
(
num_features
));
// Linear layer
// Linear layer
classifier
=
torch
::
nn
::
Linear
(
num_features
,
num_classes
);
classifier
=
torch
::
nn
::
Linear
(
num_features
,
num_classes
);
...
@@ -136,7 +136,7 @@ DenseNetImpl::DenseNetImpl(
...
@@ -136,7 +136,7 @@ DenseNetImpl::DenseNetImpl(
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
);
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
);
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
...
...
torchvision/csrc/models/googlenet.cpp
View file @
b6f28ec1
...
@@ -11,7 +11,7 @@ namespace _googlenetimpl {
...
@@ -11,7 +11,7 @@ namespace _googlenetimpl {
BasicConv2dImpl
::
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
)
{
BasicConv2dImpl
::
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
)
{
options
.
bias
(
false
);
options
.
bias
(
false
);
conv
=
torch
::
nn
::
Conv2d
(
options
);
conv
=
torch
::
nn
::
Conv2d
(
options
);
bn
=
torch
::
nn
::
BatchNorm
(
bn
=
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
options
.
out_channels
()).
eps
(
0.001
));
torch
::
nn
::
BatchNormOptions
(
options
.
out_channels
()).
eps
(
0.001
));
register_module
(
"conv"
,
conv
);
register_module
(
"conv"
,
conv
);
...
@@ -155,7 +155,7 @@ void GoogLeNetImpl::_initialize_weights() {
...
@@ -155,7 +155,7 @@ void GoogLeNetImpl::_initialize_weights() {
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
torch
::
nn
::
init
::
normal_
(
M
->
weight
);
// Note: used instead of truncated
torch
::
nn
::
init
::
normal_
(
M
->
weight
);
// Note: used instead of truncated
// normal initialization
// normal initialization
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
}
...
...
torchvision/csrc/models/googlenet.h
View file @
b6f28ec1
...
@@ -10,7 +10,7 @@ namespace models {
...
@@ -10,7 +10,7 @@ namespace models {
namespace
_googlenetimpl
{
namespace
_googlenetimpl
{
struct
VISION_API
BasicConv2dImpl
:
torch
::
nn
::
Module
{
struct
VISION_API
BasicConv2dImpl
:
torch
::
nn
::
Module
{
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
BatchNorm
bn
{
nullptr
};
torch
::
nn
::
BatchNorm
2d
bn
{
nullptr
};
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
);
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
);
...
...
torchvision/csrc/models/inception.cpp
View file @
b6f28ec1
...
@@ -11,7 +11,7 @@ BasicConv2dImpl::BasicConv2dImpl(
...
@@ -11,7 +11,7 @@ BasicConv2dImpl::BasicConv2dImpl(
double
std_dev
)
{
double
std_dev
)
{
options
.
bias
(
false
);
options
.
bias
(
false
);
conv
=
torch
::
nn
::
Conv2d
(
options
);
conv
=
torch
::
nn
::
Conv2d
(
options
);
bn
=
torch
::
nn
::
BatchNorm
(
bn
=
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
options
.
out_channels
()).
eps
(
0.001
));
torch
::
nn
::
BatchNormOptions
(
options
.
out_channels
()).
eps
(
0.001
));
register_module
(
"conv"
,
conv
);
register_module
(
"conv"
,
conv
);
...
...
torchvision/csrc/models/inception.h
View file @
b6f28ec1
...
@@ -9,7 +9,7 @@ namespace models {
...
@@ -9,7 +9,7 @@ namespace models {
namespace
_inceptionimpl
{
namespace
_inceptionimpl
{
struct
VISION_API
BasicConv2dImpl
:
torch
::
nn
::
Module
{
struct
VISION_API
BasicConv2dImpl
:
torch
::
nn
::
Module
{
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
Conv2d
conv
{
nullptr
};
torch
::
nn
::
BatchNorm
bn
{
nullptr
};
torch
::
nn
::
BatchNorm
2d
bn
{
nullptr
};
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
,
double
std_dev
=
0.1
);
BasicConv2dImpl
(
torch
::
nn
::
Conv2dOptions
options
,
double
std_dev
=
0.1
);
...
...
torchvision/csrc/models/mnasnet.cpp
View file @
b6f28ec1
...
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
...
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
apply_residual
=
input
==
output
&&
stride
==
1
;
apply_residual
=
input
==
output
&&
stride
==
1
;
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
input
,
mid
,
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
input
,
mid
,
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
mid
).
momentum
(
bn_momentum
)));
torch
::
nn
::
BatchNormOptions
(
mid
).
momentum
(
bn_momentum
)));
layers
->
push_back
(
layers
->
push_back
(
torch
::
nn
::
Functional
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
)));
torch
::
nn
::
Functional
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
)));
...
@@ -34,12 +34,12 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
...
@@ -34,12 +34,12 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
.
stride
(
stride
)
.
stride
(
stride
)
.
groups
(
mid
)
.
groups
(
mid
)
.
bias
(
false
))));
.
bias
(
false
))));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
mid
).
momentum
(
bn_momentum
)));
torch
::
nn
::
BatchNormOptions
(
mid
).
momentum
(
bn_momentum
)));
layers
->
push_back
(
layers
->
push_back
(
torch
::
nn
::
Functional
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
)));
torch
::
nn
::
Functional
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
)));
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
mid
,
output
,
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
mid
,
output
,
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
output
).
momentum
(
bn_momentum
)));
torch
::
nn
::
BatchNormOptions
(
output
).
momentum
(
bn_momentum
)));
register_module
(
"layers"
,
layers
);
register_module
(
"layers"
,
layers
);
...
@@ -109,9 +109,9 @@ void MNASNetImpl::_initialize_weights() {
...
@@ -109,9 +109,9 @@ void MNASNetImpl::_initialize_weights() {
torch
::
nn
::
init
::
kaiming_normal_
(
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
M
->
weight
,
0
,
0
,
torch
::
nn
::
init
::
FanMode
::
FanOut
,
torch
::
k
FanOut
,
torch
::
nn
::
init
::
Nonlinearity
::
ReLU
);
torch
::
k
ReLU
);
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
...
@@ -128,17 +128,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
...
@@ -128,17 +128,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
layers
->
push_back
(
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
3
,
32
,
3
).
padding
(
1
).
stride
(
2
).
bias
(
false
)));
torch
::
nn
::
Conv2d
(
Options
(
3
,
32
,
3
).
padding
(
1
).
stride
(
2
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
32
).
momentum
(
BN_MOMENTUM
)));
torch
::
nn
::
BatchNormOptions
(
32
).
momentum
(
BN_MOMENTUM
)));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
32
,
32
,
3
).
padding
(
1
).
stride
(
1
).
groups
(
32
).
bias
(
false
)));
Options
(
32
,
32
,
3
).
padding
(
1
).
stride
(
1
).
groups
(
32
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
32
).
momentum
(
BN_MOMENTUM
)));
torch
::
nn
::
BatchNormOptions
(
32
).
momentum
(
BN_MOMENTUM
)));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
layers
->
push_back
(
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
32
,
16
,
1
).
padding
(
0
).
stride
(
1
).
bias
(
false
)));
torch
::
nn
::
Conv2d
(
Options
(
32
,
16
,
1
).
padding
(
0
).
stride
(
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
16
).
momentum
(
BN_MOMENTUM
)));
torch
::
nn
::
BatchNormOptions
(
16
).
momentum
(
BN_MOMENTUM
)));
layers
->
push_back
(
stack
(
16
,
depths
[
0
],
3
,
2
,
3
,
3
,
BN_MOMENTUM
));
layers
->
push_back
(
stack
(
16
,
depths
[
0
],
3
,
2
,
3
,
3
,
BN_MOMENTUM
));
...
@@ -150,7 +150,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
...
@@ -150,7 +150,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
layers
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
depths
[
5
],
1280
,
1
).
padding
(
0
).
stride
(
1
).
bias
(
false
)));
Options
(
depths
[
5
],
1280
,
1
).
padding
(
0
).
stride
(
1
).
bias
(
false
)));
layers
->
push_back
(
torch
::
nn
::
BatchNorm
(
layers
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
torch
::
nn
::
BatchNormOptions
(
1280
).
momentum
(
BN_MOMENTUM
)));
torch
::
nn
::
BatchNormOptions
(
1280
).
momentum
(
BN_MOMENTUM
)));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
layers
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
...
...
torchvision/csrc/models/mobilenet.cpp
View file @
b6f28ec1
...
@@ -33,7 +33,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
...
@@ -33,7 +33,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
.
padding
(
padding
)
.
padding
(
padding
)
.
groups
(
groups
)
.
groups
(
groups
)
.
bias
(
false
)));
.
bias
(
false
)));
push_back
(
torch
::
nn
::
BatchNorm
(
out_planes
));
push_back
(
torch
::
nn
::
BatchNorm
2d
(
out_planes
));
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu6_
));
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu6_
));
}
}
...
@@ -68,7 +68,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
...
@@ -68,7 +68,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
conv
->
push_back
(
ConvBNReLU
(
hidden_dim
,
hidden_dim
,
3
,
stride
,
hidden_dim
));
conv
->
push_back
(
ConvBNReLU
(
hidden_dim
,
hidden_dim
,
3
,
stride
,
hidden_dim
));
conv
->
push_back
(
torch
::
nn
::
Conv2d
(
conv
->
push_back
(
torch
::
nn
::
Conv2d
(
Options
(
hidden_dim
,
output
,
1
).
stride
(
1
).
padding
(
0
).
bias
(
false
)));
Options
(
hidden_dim
,
output
,
1
).
stride
(
1
).
padding
(
0
).
bias
(
false
)));
conv
->
push_back
(
torch
::
nn
::
BatchNorm
(
output
));
conv
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
output
));
register_module
(
"conv"
,
conv
);
register_module
(
"conv"
,
conv
);
}
}
...
@@ -135,10 +135,10 @@ MobileNetV2Impl::MobileNetV2Impl(
...
@@ -135,10 +135,10 @@ MobileNetV2Impl::MobileNetV2Impl(
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
{
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
kaiming_normal_
(
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
0
,
torch
::
nn
::
init
::
FanMode
::
FanOut
);
M
->
weight
,
0
,
torch
::
k
FanOut
);
if
(
M
->
options
.
bias
())
if
(
M
->
options
.
bias
())
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
...
...
torchvision/csrc/models/resnet.cpp
View file @
b6f28ec1
...
@@ -40,8 +40,8 @@ BasicBlock::BasicBlock(
...
@@ -40,8 +40,8 @@ BasicBlock::BasicBlock(
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
);
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
);
conv2
=
conv3x3
(
planes
,
planes
);
conv2
=
conv3x3
(
planes
,
planes
);
bn1
=
torch
::
nn
::
BatchNorm
(
planes
);
bn1
=
torch
::
nn
::
BatchNorm
2d
(
planes
);
bn2
=
torch
::
nn
::
BatchNorm
(
planes
);
bn2
=
torch
::
nn
::
BatchNorm
2d
(
planes
);
register_module
(
"conv1"
,
conv1
);
register_module
(
"conv1"
,
conv1
);
register_module
(
"conv2"
,
conv2
);
register_module
(
"conv2"
,
conv2
);
...
@@ -68,9 +68,9 @@ Bottleneck::Bottleneck(
...
@@ -68,9 +68,9 @@ Bottleneck::Bottleneck(
conv2
=
conv3x3
(
width
,
width
,
stride
,
groups
);
conv2
=
conv3x3
(
width
,
width
,
stride
,
groups
);
conv3
=
conv1x1
(
width
,
planes
*
expansion
);
conv3
=
conv1x1
(
width
,
planes
*
expansion
);
bn1
=
torch
::
nn
::
BatchNorm
(
width
);
bn1
=
torch
::
nn
::
BatchNorm
2d
(
width
);
bn2
=
torch
::
nn
::
BatchNorm
(
width
);
bn2
=
torch
::
nn
::
BatchNorm
2d
(
width
);
bn3
=
torch
::
nn
::
BatchNorm
(
planes
*
expansion
);
bn3
=
torch
::
nn
::
BatchNorm
2d
(
planes
*
expansion
);
register_module
(
"conv1"
,
conv1
);
register_module
(
"conv1"
,
conv1
);
register_module
(
"conv2"
,
conv2
);
register_module
(
"conv2"
,
conv2
);
...
...
torchvision/csrc/models/resnet.h
View file @
b6f28ec1
...
@@ -28,7 +28,7 @@ struct VISION_API BasicBlock : torch::nn::Module {
...
@@ -28,7 +28,7 @@ struct VISION_API BasicBlock : torch::nn::Module {
torch
::
nn
::
Sequential
downsample
;
torch
::
nn
::
Sequential
downsample
;
torch
::
nn
::
Conv2d
conv1
{
nullptr
},
conv2
{
nullptr
};
torch
::
nn
::
Conv2d
conv1
{
nullptr
},
conv2
{
nullptr
};
torch
::
nn
::
BatchNorm
bn1
{
nullptr
},
bn2
{
nullptr
};
torch
::
nn
::
BatchNorm
2d
bn1
{
nullptr
},
bn2
{
nullptr
};
static
int
expansion
;
static
int
expansion
;
...
@@ -51,7 +51,7 @@ struct VISION_API Bottleneck : torch::nn::Module {
...
@@ -51,7 +51,7 @@ struct VISION_API Bottleneck : torch::nn::Module {
torch
::
nn
::
Sequential
downsample
;
torch
::
nn
::
Sequential
downsample
;
torch
::
nn
::
Conv2d
conv1
{
nullptr
},
conv2
{
nullptr
},
conv3
{
nullptr
};
torch
::
nn
::
Conv2d
conv1
{
nullptr
},
conv2
{
nullptr
},
conv3
{
nullptr
};
torch
::
nn
::
BatchNorm
bn1
{
nullptr
},
bn2
{
nullptr
},
bn3
{
nullptr
};
torch
::
nn
::
BatchNorm
2d
bn1
{
nullptr
},
bn2
{
nullptr
},
bn3
{
nullptr
};
static
int
expansion
;
static
int
expansion
;
...
@@ -71,7 +71,7 @@ template <typename Block>
...
@@ -71,7 +71,7 @@ template <typename Block>
struct
ResNetImpl
:
torch
::
nn
::
Module
{
struct
ResNetImpl
:
torch
::
nn
::
Module
{
int64_t
groups
,
base_width
,
inplanes
;
int64_t
groups
,
base_width
,
inplanes
;
torch
::
nn
::
Conv2d
conv1
;
torch
::
nn
::
Conv2d
conv1
;
torch
::
nn
::
BatchNorm
bn1
;
torch
::
nn
::
BatchNorm
2d
bn1
;
torch
::
nn
::
Sequential
layer1
,
layer2
,
layer3
,
layer4
;
torch
::
nn
::
Sequential
layer1
,
layer2
,
layer3
,
layer4
;
torch
::
nn
::
Linear
fc
;
torch
::
nn
::
Linear
fc
;
...
@@ -99,7 +99,7 @@ torch::nn::Sequential ResNetImpl<Block>::_make_layer(
...
@@ -99,7 +99,7 @@ torch::nn::Sequential ResNetImpl<Block>::_make_layer(
if
(
stride
!=
1
||
inplanes
!=
planes
*
Block
::
expansion
)
{
if
(
stride
!=
1
||
inplanes
!=
planes
*
Block
::
expansion
)
{
downsample
=
torch
::
nn
::
Sequential
(
downsample
=
torch
::
nn
::
Sequential
(
_resnetimpl
::
conv1x1
(
inplanes
,
planes
*
Block
::
expansion
,
stride
),
_resnetimpl
::
conv1x1
(
inplanes
,
planes
*
Block
::
expansion
,
stride
),
torch
::
nn
::
BatchNorm
(
planes
*
Block
::
expansion
));
torch
::
nn
::
BatchNorm
2d
(
planes
*
Block
::
expansion
));
}
}
torch
::
nn
::
Sequential
layers
;
torch
::
nn
::
Sequential
layers
;
...
@@ -146,9 +146,9 @@ ResNetImpl<Block>::ResNetImpl(
...
@@ -146,9 +146,9 @@ ResNetImpl<Block>::ResNetImpl(
torch
::
nn
::
init
::
kaiming_normal_
(
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
M
->
weight
,
/*a=*/
0
,
/*a=*/
0
,
torch
::
nn
::
init
::
FanMode
::
FanOut
,
torch
::
k
FanOut
,
torch
::
nn
::
init
::
Nonlinearity
::
ReLU
);
torch
::
k
ReLU
);
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
}
...
...
torchvision/csrc/models/shufflenetv2.cpp
View file @
b6f28ec1
...
@@ -49,20 +49,20 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
...
@@ -49,20 +49,20 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
if
(
stride
>
1
)
{
if
(
stride
>
1
)
{
branch1
=
torch
::
nn
::
Sequential
(
branch1
=
torch
::
nn
::
Sequential
(
conv33
(
inp
,
inp
,
stride
),
conv33
(
inp
,
inp
,
stride
),
torch
::
nn
::
BatchNorm
(
inp
),
torch
::
nn
::
BatchNorm
2d
(
inp
),
conv11
(
inp
,
branch_features
),
conv11
(
inp
,
branch_features
),
torch
::
nn
::
BatchNorm
(
branch_features
),
torch
::
nn
::
BatchNorm
2d
(
branch_features
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
}
}
branch2
=
torch
::
nn
::
Sequential
(
branch2
=
torch
::
nn
::
Sequential
(
conv11
(
stride
>
1
?
inp
:
branch_features
,
branch_features
),
conv11
(
stride
>
1
?
inp
:
branch_features
,
branch_features
),
torch
::
nn
::
BatchNorm
(
branch_features
),
torch
::
nn
::
BatchNorm
2d
(
branch_features
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
),
conv33
(
branch_features
,
branch_features
,
stride
),
conv33
(
branch_features
,
branch_features
,
stride
),
torch
::
nn
::
BatchNorm
(
branch_features
),
torch
::
nn
::
BatchNorm
2d
(
branch_features
),
conv11
(
branch_features
,
branch_features
),
conv11
(
branch_features
,
branch_features
),
torch
::
nn
::
BatchNorm
(
branch_features
),
torch
::
nn
::
BatchNorm
2d
(
branch_features
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
if
(
!
branch1
.
is_empty
())
if
(
!
branch1
.
is_empty
())
...
@@ -108,7 +108,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
...
@@ -108,7 +108,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
.
stride
(
2
)
.
stride
(
2
)
.
padding
(
1
)
.
padding
(
1
)
.
bias
(
false
)),
.
bias
(
false
)),
torch
::
nn
::
BatchNorm
(
output_channels
),
torch
::
nn
::
BatchNorm
2d
(
output_channels
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
input_channels
=
output_channels
;
input_channels
=
output_channels
;
...
@@ -135,7 +135,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
...
@@ -135,7 +135,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
.
stride
(
1
)
.
stride
(
1
)
.
padding
(
0
)
.
padding
(
0
)
.
bias
(
false
)),
.
bias
(
false
)),
torch
::
nn
::
BatchNorm
(
output_channels
),
torch
::
nn
::
BatchNorm
2d
(
output_channels
),
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
fc
=
torch
::
nn
::
Linear
(
output_channels
,
num_classes
);
fc
=
torch
::
nn
::
Linear
(
output_channels
,
num_classes
);
...
...
torchvision/csrc/models/vgg.cpp
View file @
b6f28ec1
...
@@ -19,7 +19,7 @@ torch::nn::Sequential makeLayers(
...
@@ -19,7 +19,7 @@ torch::nn::Sequential makeLayers(
torch
::
nn
::
Conv2dOptions
(
channels
,
V
,
3
).
padding
(
1
)));
torch
::
nn
::
Conv2dOptions
(
channels
,
V
,
3
).
padding
(
1
)));
if
(
batch_norm
)
if
(
batch_norm
)
seq
->
push_back
(
torch
::
nn
::
BatchNorm
(
V
));
seq
->
push_back
(
torch
::
nn
::
BatchNorm
2d
(
V
));
seq
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
seq
->
push_back
(
torch
::
nn
::
Functional
(
modelsimpl
::
relu_
));
channels
=
V
;
channels
=
V
;
...
@@ -35,10 +35,10 @@ void VGGImpl::_initialize_weights() {
...
@@ -35,10 +35,10 @@ void VGGImpl::_initialize_weights() {
torch
::
nn
::
init
::
kaiming_normal_
(
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
M
->
weight
,
/*a=*/
0
,
/*a=*/
0
,
torch
::
nn
::
init
::
FanMode
::
FanOut
,
torch
::
k
FanOut
,
torch
::
nn
::
init
::
Nonlinearity
::
ReLU
);
torch
::
k
ReLU
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNormImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm
2d
Impl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment