Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
601ce5fc
Unverified
Commit
601ce5fc
authored
Mar 12, 2020
by
Francisco Massa
Committed by
GitHub
Mar 12, 2020
Browse files
Fix C++ linnt (#1971)
parent
12be107b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
11 deletions
+20
-11
torchvision/csrc/ROIAlign.h
torchvision/csrc/ROIAlign.h
+14
-3
torchvision/csrc/models/mnasnet.cpp
torchvision/csrc/models/mnasnet.cpp
+1
-4
torchvision/csrc/models/mobilenet.cpp
torchvision/csrc/models/mobilenet.cpp
+3
-3
torchvision/csrc/models/vgg.cpp
torchvision/csrc/models/vgg.cpp
+2
-1
No files found.
torchvision/csrc/ROIAlign.h
View file @
601ce5fc
...
...
@@ -36,7 +36,13 @@ at::Tensor ROIAlign_forward(
#endif
}
return
ROIAlign_forward_cpu
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
);
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
);
}
at
::
Tensor
ROIAlign_backward
(
...
...
@@ -137,8 +143,13 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
input_shape
[
3
],
ctx
->
saved_data
[
"sampling_ratio"
].
toInt
(),
ctx
->
saved_data
[
"aligned"
].
toBool
());
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
(),
Variable
(),
Variable
(),
Variable
()};
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
(),
Variable
(),
Variable
(),
Variable
()};
}
};
...
...
torchvision/csrc/models/mnasnet.cpp
View file @
601ce5fc
...
...
@@ -107,10 +107,7 @@ void MNASNetImpl::_initialize_weights() {
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
0
,
torch
::
kFanOut
,
torch
::
kReLU
);
M
->
weight
,
0
,
torch
::
kFanOut
,
torch
::
kReLU
);
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm2dImpl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
...
...
torchvision/csrc/models/mobilenet.cpp
View file @
601ce5fc
...
...
@@ -134,11 +134,11 @@ MobileNetV2Impl::MobileNetV2Impl(
for
(
auto
&
module
:
modules
(
/*include_self=*/
false
))
{
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
Conv2dImpl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
0
,
torch
::
kFanOut
);
torch
::
nn
::
init
::
kaiming_normal_
(
M
->
weight
,
0
,
torch
::
kFanOut
);
if
(
M
->
options
.
bias
())
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm2dImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm2dImpl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
ones_
(
M
->
weight
);
torch
::
nn
::
init
::
zeros_
(
M
->
bias
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
...
...
torchvision/csrc/models/vgg.cpp
View file @
601ce5fc
...
...
@@ -38,7 +38,8 @@ void VGGImpl::_initialize_weights() {
torch
::
kFanOut
,
torch
::
kReLU
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm2dImpl
*>
(
module
.
get
()))
{
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
BatchNorm2dImpl
*>
(
module
.
get
()))
{
torch
::
nn
::
init
::
constant_
(
M
->
weight
,
1
);
torch
::
nn
::
init
::
constant_
(
M
->
bias
,
0
);
}
else
if
(
auto
M
=
dynamic_cast
<
torch
::
nn
::
LinearImpl
*>
(
module
.
get
()))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment