Unverified commit 44a5bae9, authored by Will Feng, committed by GitHub


Rename with_bias() to bias(), and output_channels() to out_channels() in C++ conv layer options usage (#1576)
parent 681c6c11
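For readers unfamiliar with the options API, here is a minimal sketch of the renamed accessors in use. It is not taken from the commit's sources: the channel counts and tensor shape are invented for illustration, and it assumes a libtorch build of the same era as this commit, where `torch::nn::BatchNorm` (used throughout these files) still exists.

```cpp
#include <torch/torch.h>

int main() {
  // was: .with_bias(false) before this commit
  auto opts = torch::nn::Conv2dOptions(/*in_channels=*/3,
                                       /*out_channels=*/64,
                                       /*kernel_size=*/3)
                  .stride(1)
                  .padding(1)
                  .bias(false); // chainable setter: disable the conv bias

  torch::nn::Conv2d conv(opts);
  // was: options.output_channels() before this commit; the no-argument
  // call is the getter form of the same accessor.
  torch::nn::BatchNorm bn(
      torch::nn::BatchNormOptions(opts.out_channels()).eps(0.001));

  auto y = bn(conv(torch::randn({1, 3, 32, 32})));
}
```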
@@ -21,7 +21,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
"conv1",
torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
.stride(1)
-.with_bias(false)));
+.bias(false)));
push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
push_back(
@@ -29,7 +29,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3)
.stride(1)
.padding(1)
-.with_bias(false)));
+.bias(false)));
}
torch::Tensor forward(torch::Tensor x) {
@@ -75,7 +75,7 @@ struct _TransitionImpl : torch::nn::SequentialImpl {
"conv",
torch::nn::Conv2d(Options(num_input_features, num_output_features, 1)
.stride(1)
-.with_bias(false)));
+.bias(false)));
push_back("pool", torch::nn::Functional([](torch::Tensor input) {
return torch::avg_pool2d(input, 2, 2, 0, false, true);
}));
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
torch::nn::Conv2d(Options(3, num_init_features, 7)
.stride(2)
.padding(3)
-.with_bias(false)));
+.bias(false)));
features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
......
@@ -9,10 +9,10 @@ using Options = torch::nn::Conv2dOptions;
namespace _googlenetimpl {
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
-options.with_bias(false);
+options.bias(false);
conv = torch::nn::Conv2d(options);
bn = torch::nn::BatchNorm(
-torch::nn::BatchNormOptions(options.output_channels()).eps(0.001));
+torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
register_module("conv", conv);
register_module("bn", bn);
......
@@ -9,10 +9,10 @@ namespace _inceptionimpl {
BasicConv2dImpl::BasicConv2dImpl(
torch::nn::Conv2dOptions options,
double std_dev) {
-options.with_bias(false);
+options.bias(false);
conv = torch::nn::Conv2d(options);
bn = torch::nn::BatchNorm(
-torch::nn::BatchNormOptions(options.output_channels()).eps(0.001));
+torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
register_module("conv", conv);
register_module("bn", bn);
......
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
apply_residual = input == output && stride == 1;
layers->push_back(
-torch::nn::Conv2d(Options(input, mid, 1).with_bias(false)));
+torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back(
@@ -34,13 +34,13 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
.padding(kernel / 2)
.stride(stride)
.groups(mid)
-.with_bias(false))));
+.bias(false))));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back(
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
layers->push_back(
-torch::nn::Conv2d(Options(mid, output, 1).with_bias(false)));
+torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(output).momentum(bn_momentum)));
@@ -129,17 +129,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
auto depths = scale_depths({24, 40, 80, 96, 192, 320}, alpha);
layers->push_back(torch::nn::Conv2d(
-Options(3, 32, 3).padding(1).stride(2).with_bias(false)));
+Options(3, 32, 3).padding(1).stride(2).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(torch::nn::Conv2d(
-Options(32, 32, 3).padding(1).stride(1).groups(32).with_bias(false)));
+Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(torch::nn::Conv2d(
-Options(32, 16, 1).padding(0).stride(1).with_bias(false)));
+Options(32, 16, 1).padding(0).stride(1).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));
@@ -151,7 +151,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
layers->push_back(stack(depths[4], depths[5], 3, 1, 6, 1, BN_MOMENTUM));
layers->push_back(torch::nn::Conv2d(
-Options(depths[5], 1280, 1).padding(0).stride(1).with_bias(false)));
+Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_));
......
@@ -32,7 +32,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
.stride(stride)
.padding(padding)
.groups(groups)
-.with_bias(false)));
+.bias(false)));
push_back(torch::nn::BatchNorm(out_planes));
push_back(torch::nn::Functional(modelsimpl::relu6_));
}
@@ -67,7 +67,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
conv->push_back(torch::nn::Conv2d(
-Options(hidden_dim, output, 1).stride(1).padding(0).with_bias(false)));
+Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
conv->push_back(torch::nn::BatchNorm(output));
register_module("conv", conv);
@@ -136,7 +136,7 @@ MobileNetV2Impl::MobileNetV2Impl(
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::nn::init::FanMode::FanOut);
-if (M->options.with_bias())
+if (M->options.bias())
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
......
@@ -11,13 +11,13 @@ torch::nn::Conv2d conv3x3(
int64_t stride,
int64_t groups) {
torch::nn::Conv2dOptions O(in, out, 3);
-O.padding(1).stride(stride).groups(groups).with_bias(false);
+O.padding(1).stride(stride).groups(groups).bias(false);
return torch::nn::Conv2d(O);
}
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) {
torch::nn::Conv2dOptions O(in, out, 1);
-O.stride(stride).with_bias(false);
+O.stride(stride).bias(false);
return torch::nn::Conv2d(O);
}
......
@@ -124,7 +124,7 @@ ResNetImpl<Block>::ResNetImpl(
: groups(groups),
base_width(width_per_group),
inplanes(64),
-conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).with_bias(
+conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(
false)),
bn1(64),
layer1(_make_layer(64, layers[0])),
......
@@ -25,13 +25,13 @@ torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) {
torch::nn::Conv2d conv11(int64_t input, int64_t output) {
Options opts(input, output, 1);
-opts = opts.stride(1).padding(0).with_bias(false);
+opts = opts.stride(1).padding(0).bias(false);
return torch::nn::Conv2d(opts);
}
torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) {
Options opts(input, output, 3);
-opts = opts.stride(stride).padding(1).with_bias(false).groups(input);
+opts = opts.stride(stride).padding(1).bias(false).groups(input);
return torch::nn::Conv2d(opts);
}
@@ -107,7 +107,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
torch::nn::Conv2d(Options(input_channels, output_channels, 3)
.stride(2)
.padding(1)
-.with_bias(false)),
+.bias(false)),
torch::nn::BatchNorm(output_channels),
torch::nn::Functional(modelsimpl::relu_));
@@ -134,7 +134,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
torch::nn::Conv2d(Options(input_channels, output_channels, 1)
.stride(1)
.padding(0)
-.with_bias(false)),
+.bias(false)),
torch::nn::BatchNorm(output_channels),
torch::nn::Functional(modelsimpl::relu_));
......
@@ -91,7 +91,7 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
else
torch::nn::init::kaiming_uniform_(M->weight);
-if (M->options.with_bias())
+if (M->options.bias())
torch::nn::init::constant_(M->bias, 0);
}
}
......
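A note on the accessor style seen in the last two hunks, where the renamed method is read back (`M->options.bias()`) during weight initialization rather than set: libtorch generates these option accessors (via its TORCH_ARG macro), so a single name serves as a chainable setter when called with an argument and as a getter when called without one. The following is a hand-rolled sketch of that pattern, not libtorch code; the `FakeOptions` type is invented for illustration.

```cpp
// Illustration only: roughly the shape of what libtorch generates
// for an option named `bias`.
struct FakeOptions {
  // Setter: takes a value, returns *this so calls can be chained.
  FakeOptions& bias(bool value) {
    bias_ = value;
    return *this;
  }
  // Getter: same name, no argument, returns the stored value.
  bool bias() const {
    return bias_;
  }

 private:
  bool bias_ = true; // conv layers default to having a bias
};

int main() {
  FakeOptions opts;
  opts.bias(false);        // setter, as in Options(...).bias(false)
  bool has = opts.bias();  // getter, as in M->options.bias()
  return has ? 1 : 0;
}
```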