Commit fecd1385 authored by Shahriar, committed by Francisco Massa

Update C++ Models to use TORCH_CHECK instead of asserts (#1144)

* Replaced asserts with TORCH_CHECK

* Fixed an error
parent 737966a3
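
Background for the change: assert() compiles to a no-op when NDEBUG is defined, as it is in typical release builds of libtorch, so these argument checks would silently vanish; TORCH_CHECK always evaluates its condition and throws a c10::Error carrying the message when the condition fails. A minimal sketch of the difference, where check_stride is an illustrative helper rather than part of the patch:

    #include <torch/torch.h>

    // assert(stride == 1 || stride == 2) would be a no-op under NDEBUG,
    // letting a bad stride slip through in release builds. TORCH_CHECK is
    // always evaluated; on failure it throws c10::Error, and any trailing
    // arguments are stringified into the error message.
    void check_stride(int64_t stride) {
      TORCH_CHECK(stride == 1 || stride == 2, "stride must be 1 or 2, got ", stride);
    }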
@@ -17,8 +17,8 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
       int64_t stride,
       double expansion_factor,
       double bn_momentum = 0.1) {
-    assert(stride == 1 || stride == 2);
-    assert(kernel == 3 || kernel == 5);
+    TORCH_CHECK(stride == 1 || stride == 2);
+    TORCH_CHECK(kernel == 3 || kernel == 5);
     auto mid = int64_t(input * expansion_factor);
     apply_residual = input == output && stride == 1;
@@ -74,7 +74,7 @@ StackSequentail stack(
     double exp_factor,
     int64_t repeats,
     double bn_momentum) {
-  assert(repeats >= 1);
+  TORCH_CHECK(repeats >= 1);
   StackSequentail seq;
   seq->push_back(MNASNetInvertedResidual(
@@ -91,7 +91,7 @@ int64_t round_to_multiple_of(
     int64_t val,
     int64_t divisor,
     double round_up_bias = .9) {
-  assert(0.0 < round_up_bias && round_up_bias < 1.0);
+  TORCH_CHECK(0.0 < round_up_bias && round_up_bias < 1.0);
   auto new_val = std::max(divisor, (val + divisor / 2) / divisor * divisor);
   return new_val >= round_up_bias * val ? new_val : new_val + divisor;
 }
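
Worked example of the bias guarded by the new check, assuming the default round_up_bias of .9: round_to_multiple_of(72, 64) computes new_val = std::max(64, (72 + 32) / 64 * 64) = 64 in integer arithmetic, but 64 < 0.9 * 72 = 64.8, so the function returns 64 + 64 = 128 instead. The bias thus keeps rounding from shrinking a channel count by more than roughly 10%.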
...
@@ -59,7 +59,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
       return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
     };
-    assert(stride == 1 || stride == 2);
+    TORCH_CHECK(stride == 1 || stride == 2);
     auto hidden_dim = int64_t(std::round(input * expand_ratio));
     if (!double_compare(expand_ratio, 1))
@@ -103,10 +103,9 @@ MobileNetV2Impl::MobileNetV2Impl(
       {6, 320, 1, 1},
   };
-  if (inverted_residual_settings[0].size() != 4) {
-    std::cerr << "inverted_residual_settings should contain 4-element vectors";
-    assert(false);
-  }
+  TORCH_CHECK(
+      inverted_residual_settings[0].size() == 4,
+      "inverted_residual_settings should contain 4-element vectors");
   input_channel = make_divisible(input_channel * width_mult, round_nearest);
   this->last_channel =
...
@@ -3,6 +3,10 @@
 #include <torch/torch.h>

+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
 namespace vision {
 namespace models {
 namespace modelsimpl {
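
The #ifndef guard above is presumably a compatibility shim: older libtorch releases expose this macro only under its previous name, AT_CHECK, so the fallback lets the same sources build there, while newer releases that already define TORCH_CHECK skip it.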
...
#include "resnet.h" #include "resnet.h"
#include "modelsimpl.h"
namespace vision { namespace vision {
namespace models { namespace models {
namespace _resnetimpl { namespace _resnetimpl {
@@ -30,11 +32,9 @@ BasicBlock::BasicBlock(
     int64_t groups,
     int64_t base_width)
     : stride(stride), downsample(downsample) {
-  if (groups != 1 || base_width != 64) {
-    std::cerr << "BasicBlock only supports groups=1 and base_width=64"
-              << std::endl;
-    assert(false);
-  }
+  TORCH_CHECK(
+      groups == 1 && base_width == 64,
+      "BasicBlock only supports groups=1 and base_width=64");
   // Both conv1 and downsample layers downsample the input when stride != 1
   conv1 = conv3x3(inplanes, planes, stride);
...
@@ -72,8 +72,8 @@ struct ResNetImpl : torch::nn::Module {
   int64_t groups, base_width, inplanes;
   torch::nn::Conv2d conv1;
   torch::nn::BatchNorm bn1;
-  torch::nn::Linear fc;
   torch::nn::Sequential layer1, layer2, layer3, layer4;
+  torch::nn::Linear fc;

   torch::nn::Sequential _make_layer(
       int64_t planes,
...
@@ -41,13 +41,10 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
   ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride)
       : stride(stride) {
-    if (stride < 1 || stride > 3) {
-      std::cerr << "illegal stride value'" << std::endl;
-      assert(false);
-    }
+    TORCH_CHECK(stride >= 1 && stride <= 3, "illegal stride value");

     auto branch_features = oup / 2;
-    assert(stride != 1 || inp == branch_features << 1);
+    TORCH_CHECK(stride != 1 || inp == branch_features << 1);

     if (stride > 1) {
       branch1 = torch::nn::Sequential(
@@ -94,17 +91,13 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
     const std::vector<int64_t>& stage_repeats,
     const std::vector<int64_t>& stage_out_channels,
     int64_t num_classes) {
-  if (stage_repeats.size() != 3) {
-    std::cerr << "expected stage_repeats as vector of 3 positive ints"
-              << std::endl;
-    assert(false);
-  }
-  if (stage_out_channels.size() != 5) {
-    std::cerr << "expected stage_out_channels as vector of 5 positive ints"
-              << std::endl;
-    assert(false);
-  }
+  TORCH_CHECK(
+      stage_repeats.size() == 3,
+      "expected stage_repeats as vector of 3 positive ints");
+  TORCH_CHECK(
+      stage_out_channels.size() == 5,
+      "expected stage_out_channels as vector of 5 positive ints");

   _stage_out_channels = stage_out_channels;
   int64_t input_channels = 3;
...
@@ -64,11 +64,12 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
         Fire(384, 48, 192, 192),
         Fire(384, 64, 256, 256),
         Fire(512, 64, 256, 256));
-  } else {
-    std::cerr << "Unsupported SqueezeNet version " << version
-              << ". 1_0 or 1_1 expected" << std::endl;
-    assert(false);
-  }
+  } else
+    TORCH_CHECK(
+        false,
+        "Unsupported SqueezeNet version ",
+        version,
+        ". 1_0 or 1_1 expected");

   // Final convolution is initialized differently from the rest
   auto final_conv =
...