Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
fecd1385
Commit
fecd1385
authored
Jul 23, 2019
by
Shahriar
Committed by
Francisco Massa
Jul 23, 2019
Browse files
Update C++ Models to use TORCH_CHECK instead of asserts (#1144)
* Replaced asserts with TORCH_CHECK * Fixed an error
parent
737966a3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
35 deletions
+32
-35
torchvision/csrc/models/mnasnet.cpp
torchvision/csrc/models/mnasnet.cpp
+4
-4
torchvision/csrc/models/mobilenet.cpp
torchvision/csrc/models/mobilenet.cpp
+4
-5
torchvision/csrc/models/modelsimpl.h
torchvision/csrc/models/modelsimpl.h
+4
-0
torchvision/csrc/models/resnet.cpp
torchvision/csrc/models/resnet.cpp
+5
-5
torchvision/csrc/models/resnet.h
torchvision/csrc/models/resnet.h
+1
-1
torchvision/csrc/models/shufflenetv2.cpp
torchvision/csrc/models/shufflenetv2.cpp
+8
-15
torchvision/csrc/models/squeezenet.cpp
torchvision/csrc/models/squeezenet.cpp
+6
-5
No files found.
torchvision/csrc/models/mnasnet.cpp
View file @
fecd1385
...
...
@@ -17,8 +17,8 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
int64_t
stride
,
double
expansion_factor
,
double
bn_momentum
=
0.1
)
{
assert
(
stride
==
1
||
stride
==
2
);
assert
(
kernel
==
3
||
kernel
==
5
);
TORCH_CHECK
(
stride
==
1
||
stride
==
2
);
TORCH_CHECK
(
kernel
==
3
||
kernel
==
5
);
auto
mid
=
int64_t
(
input
*
expansion_factor
);
apply_residual
=
input
==
output
&&
stride
==
1
;
...
...
@@ -74,7 +74,7 @@ StackSequentail stack(
double
exp_factor
,
int64_t
repeats
,
double
bn_momentum
)
{
assert
(
repeats
>=
1
);
TORCH_CHECK
(
repeats
>=
1
);
StackSequentail
seq
;
seq
->
push_back
(
MNASNetInvertedResidual
(
...
...
@@ -91,7 +91,7 @@ int64_t round_to_multiple_of(
int64_t
val
,
int64_t
divisor
,
double
round_up_bias
=
.9
)
{
assert
(
0.0
<
round_up_bias
&&
round_up_bias
<
1.0
);
TORCH_CHECK
(
0.0
<
round_up_bias
&&
round_up_bias
<
1.0
);
auto
new_val
=
std
::
max
(
divisor
,
(
val
+
divisor
/
2
)
/
divisor
*
divisor
);
return
new_val
>=
round_up_bias
*
val
?
new_val
:
new_val
+
divisor
;
}
...
...
torchvision/csrc/models/mobilenet.cpp
View file @
fecd1385
...
...
@@ -59,7 +59,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
return
double
(
std
::
abs
(
a
-
b
))
<
std
::
numeric_limits
<
double
>::
epsilon
();
};
assert
(
stride
==
1
||
stride
==
2
);
TORCH_CHECK
(
stride
==
1
||
stride
==
2
);
auto
hidden_dim
=
int64_t
(
std
::
round
(
input
*
expand_ratio
));
if
(
!
double_compare
(
expand_ratio
,
1
))
...
...
@@ -103,10 +103,9 @@ MobileNetV2Impl::MobileNetV2Impl(
{
6
,
320
,
1
,
1
},
};
if
(
inverted_residual_settings
[
0
].
size
()
!=
4
)
{
std
::
cerr
<<
"inverted_residual_settings should contain 4-element vectors"
;
assert
(
false
);
}
TORCH_CHECK
(
inverted_residual_settings
[
0
].
size
()
==
4
,
"inverted_residual_settings should contain 4-element vectors"
);
input_channel
=
make_divisible
(
input_channel
*
width_mult
,
round_nearest
);
this
->
last_channel
=
...
...
torchvision/csrc/models/modelsimpl.h
View file @
fecd1385
...
...
@@ -3,6 +3,10 @@
#include <torch/torch.h>
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
namespace
vision
{
namespace
models
{
namespace
modelsimpl
{
...
...
torchvision/csrc/models/resnet.cpp
View file @
fecd1385
#include "resnet.h"
#include "modelsimpl.h"
namespace
vision
{
namespace
models
{
namespace
_resnetimpl
{
...
...
@@ -30,11 +32,9 @@ BasicBlock::BasicBlock(
int64_t
groups
,
int64_t
base_width
)
:
stride
(
stride
),
downsample
(
downsample
)
{
if
(
groups
!=
1
||
base_width
!=
64
)
{
std
::
cerr
<<
"BasicBlock only supports groups=1 and base_width=64"
<<
std
::
endl
;
assert
(
false
);
}
TORCH_CHECK
(
groups
==
1
&&
base_width
==
64
,
"BasicBlock only supports groups=1 and base_width=64"
);
// Both conv1 and downsample layers downsample the input when stride != 1
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
);
...
...
torchvision/csrc/models/resnet.h
View file @
fecd1385
...
...
@@ -72,8 +72,8 @@ struct ResNetImpl : torch::nn::Module {
int64_t
groups
,
base_width
,
inplanes
;
torch
::
nn
::
Conv2d
conv1
;
torch
::
nn
::
BatchNorm
bn1
;
torch
::
nn
::
Linear
fc
;
torch
::
nn
::
Sequential
layer1
,
layer2
,
layer3
,
layer4
;
torch
::
nn
::
Linear
fc
;
torch
::
nn
::
Sequential
_make_layer
(
int64_t
planes
,
...
...
torchvision/csrc/models/shufflenetv2.cpp
View file @
fecd1385
...
...
@@ -41,13 +41,10 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
ShuffleNetV2InvertedResidualImpl
(
int64_t
inp
,
int64_t
oup
,
int64_t
stride
)
:
stride
(
stride
)
{
if
(
stride
<
1
||
stride
>
3
)
{
std
::
cerr
<<
"illegal stride value'"
<<
std
::
endl
;
assert
(
false
);
}
TORCH_CHECK
(
stride
>=
1
&&
stride
<=
3
,
"illegal stride value"
);
auto
branch_features
=
oup
/
2
;
assert
(
stride
!=
1
||
inp
==
branch_features
<<
1
);
TORCH_CHECK
(
stride
!=
1
||
inp
==
branch_features
<<
1
);
if
(
stride
>
1
)
{
branch1
=
torch
::
nn
::
Sequential
(
...
...
@@ -94,17 +91,13 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
const
std
::
vector
<
int64_t
>&
stage_repeats
,
const
std
::
vector
<
int64_t
>&
stage_out_channels
,
int64_t
num_classes
)
{
if
(
stage_repeats
.
size
()
!=
3
)
{
std
::
cerr
<<
"expected stage_repeats as vector of 3 positive ints"
<<
std
::
endl
;
assert
(
false
);
}
TORCH_CHECK
(
stage_repeats
.
size
()
==
3
,
"expected stage_repeats as vector of 3 positive ints"
);
if
(
stage_out_channels
.
size
()
!=
5
)
{
std
::
cerr
<<
"expected stage_out_channels as vector of 5 positive ints"
<<
std
::
endl
;
assert
(
false
);
}
TORCH_CHECK
(
stage_out_channels
.
size
()
==
5
,
"expected stage_out_channels as vector of 5 positive ints"
);
_stage_out_channels
=
stage_out_channels
;
int64_t
input_channels
=
3
;
...
...
torchvision/csrc/models/squeezenet.cpp
View file @
fecd1385
...
...
@@ -64,11 +64,12 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
Fire
(
384
,
48
,
192
,
192
),
Fire
(
384
,
64
,
256
,
256
),
Fire
(
512
,
64
,
256
,
256
));
}
else
{
std
::
cerr
<<
"Unsupported SqueezeNet version "
<<
version
<<
". 1_0 or 1_1 expected"
<<
std
::
endl
;
assert
(
false
);
}
}
else
TORCH_CHECK
(
false
,
"Unsupported SqueezeNet version "
,
version
,
". 1_0 or 1_1 expected"
);
// Final convolution is initialized differently from the rest
auto
final_conv
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment