Commit fecd1385 authored by Shahriar, committed by Francisco Massa
Browse files

Update C++ Models to use TORCH_CHECK instead of asserts (#1144)

* Replaced asserts with TORCH_CHECK

* Fixed an error
parent 737966a3
......@@ -17,8 +17,8 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
int64_t stride,
double expansion_factor,
double bn_momentum = 0.1) {
assert(stride == 1 || stride == 2);
assert(kernel == 3 || kernel == 5);
TORCH_CHECK(stride == 1 || stride == 2);
TORCH_CHECK(kernel == 3 || kernel == 5);
auto mid = int64_t(input * expansion_factor);
apply_residual = input == output && stride == 1;
......@@ -74,7 +74,7 @@ StackSequentail stack(
double exp_factor,
int64_t repeats,
double bn_momentum) {
assert(repeats >= 1);
TORCH_CHECK(repeats >= 1);
StackSequentail seq;
seq->push_back(MNASNetInvertedResidual(
......@@ -91,7 +91,7 @@ int64_t round_to_multiple_of(
int64_t val,
int64_t divisor,
double round_up_bias = .9) {
assert(0.0 < round_up_bias && round_up_bias < 1.0);
TORCH_CHECK(0.0 < round_up_bias && round_up_bias < 1.0);
auto new_val = std::max(divisor, (val + divisor / 2) / divisor * divisor);
return new_val >= round_up_bias * val ? new_val : new_val + divisor;
}
......
......@@ -59,7 +59,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
};
assert(stride == 1 || stride == 2);
TORCH_CHECK(stride == 1 || stride == 2);
auto hidden_dim = int64_t(std::round(input * expand_ratio));
if (!double_compare(expand_ratio, 1))
......@@ -103,10 +103,9 @@ MobileNetV2Impl::MobileNetV2Impl(
{6, 320, 1, 1},
};
if (inverted_residual_settings[0].size() != 4) {
std::cerr << "inverted_residual_settings should contain 4-element vectors";
assert(false);
}
TORCH_CHECK(
inverted_residual_settings[0].size() == 4,
"inverted_residual_settings should contain 4-element vectors");
input_channel = make_divisible(input_channel * width_mult, round_nearest);
this->last_channel =
......
......@@ -3,6 +3,10 @@
#include <torch/torch.h>
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
namespace vision {
namespace models {
namespace modelsimpl {
......
#include "resnet.h"
#include "modelsimpl.h"
namespace vision {
namespace models {
namespace _resnetimpl {
......@@ -30,11 +32,9 @@ BasicBlock::BasicBlock(
int64_t groups,
int64_t base_width)
: stride(stride), downsample(downsample) {
if (groups != 1 || base_width != 64) {
std::cerr << "BasicBlock only supports groups=1 and base_width=64"
<< std::endl;
assert(false);
}
TORCH_CHECK(
groups == 1 && base_width == 64,
"BasicBlock only supports groups=1 and base_width=64");
// Both conv1 and downsample layers downsample the input when stride != 1
conv1 = conv3x3(inplanes, planes, stride);
......
......@@ -72,8 +72,8 @@ struct ResNetImpl : torch::nn::Module {
int64_t groups, base_width, inplanes;
torch::nn::Conv2d conv1;
torch::nn::BatchNorm bn1;
torch::nn::Linear fc;
torch::nn::Sequential layer1, layer2, layer3, layer4;
torch::nn::Linear fc;
torch::nn::Sequential _make_layer(
int64_t planes,
......
......@@ -41,13 +41,10 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride)
: stride(stride) {
if (stride < 1 || stride > 3) {
std::cerr << "illegal stride value'" << std::endl;
assert(false);
}
TORCH_CHECK(stride >= 1 && stride <= 3, "illegal stride value");
auto branch_features = oup / 2;
assert(stride != 1 || inp == branch_features << 1);
TORCH_CHECK(stride != 1 || inp == branch_features << 1);
if (stride > 1) {
branch1 = torch::nn::Sequential(
......@@ -94,17 +91,13 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
const std::vector<int64_t>& stage_repeats,
const std::vector<int64_t>& stage_out_channels,
int64_t num_classes) {
if (stage_repeats.size() != 3) {
std::cerr << "expected stage_repeats as vector of 3 positive ints"
<< std::endl;
assert(false);
}
TORCH_CHECK(
stage_repeats.size() == 3,
"expected stage_repeats as vector of 3 positive ints");
if (stage_out_channels.size() != 5) {
std::cerr << "expected stage_out_channels as vector of 5 positive ints"
<< std::endl;
assert(false);
}
TORCH_CHECK(
stage_out_channels.size() == 5,
"expected stage_out_channels as vector of 5 positive ints");
_stage_out_channels = stage_out_channels;
int64_t input_channels = 3;
......
......@@ -64,11 +64,12 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256));
} else {
std::cerr << "Unsupported SqueezeNet version " << version
<< ". 1_0 or 1_1 expected" << std::endl;
assert(false);
}
} else
TORCH_CHECK(
false,
"Unsupported SqueezeNet version ",
version,
". 1_0 or 1_1 expected");
// Final convolution is initialized differently from the rest
auto final_conv =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment