Unverified Commit 44a5bae9 authored by Will Feng's avatar Will Feng Committed by GitHub
Browse files

Rename with_bias() to bias(), and output_channels() to out_channels() in C++...

Rename with_bias() to bias(), and output_channels() to out_channels() in C++ conv layer options usage (#1576)
parent 681c6c11
...@@ -21,7 +21,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl { ...@@ -21,7 +21,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
"conv1", "conv1",
torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1) torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
.stride(1) .stride(1)
.with_bias(false))); .bias(false)));
push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate)); push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
push_back("relu2", torch::nn::Functional(modelsimpl::relu_)); push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
push_back( push_back(
...@@ -29,7 +29,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl { ...@@ -29,7 +29,7 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3) torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3)
.stride(1) .stride(1)
.padding(1) .padding(1)
.with_bias(false))); .bias(false)));
} }
torch::Tensor forward(torch::Tensor x) { torch::Tensor forward(torch::Tensor x) {
...@@ -75,7 +75,7 @@ struct _TransitionImpl : torch::nn::SequentialImpl { ...@@ -75,7 +75,7 @@ struct _TransitionImpl : torch::nn::SequentialImpl {
"conv", "conv",
torch::nn::Conv2d(Options(num_input_features, num_output_features, 1) torch::nn::Conv2d(Options(num_input_features, num_output_features, 1)
.stride(1) .stride(1)
.with_bias(false))); .bias(false)));
push_back("pool", torch::nn::Functional([](torch::Tensor input) { push_back("pool", torch::nn::Functional([](torch::Tensor input) {
return torch::avg_pool2d(input, 2, 2, 0, false, true); return torch::avg_pool2d(input, 2, 2, 0, false, true);
})); }));
...@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl( ...@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
torch::nn::Conv2d(Options(3, num_init_features, 7) torch::nn::Conv2d(Options(3, num_init_features, 7)
.stride(2) .stride(2)
.padding(3) .padding(3)
.with_bias(false))); .bias(false)));
features->push_back("norm0", torch::nn::BatchNorm(num_init_features)); features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_)); features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
......
...@@ -9,10 +9,10 @@ using Options = torch::nn::Conv2dOptions; ...@@ -9,10 +9,10 @@ using Options = torch::nn::Conv2dOptions;
namespace _googlenetimpl { namespace _googlenetimpl {
BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) { BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
options.with_bias(false); options.bias(false);
conv = torch::nn::Conv2d(options); conv = torch::nn::Conv2d(options);
bn = torch::nn::BatchNorm( bn = torch::nn::BatchNorm(
torch::nn::BatchNormOptions(options.output_channels()).eps(0.001)); torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
register_module("conv", conv); register_module("conv", conv);
register_module("bn", bn); register_module("bn", bn);
......
...@@ -9,10 +9,10 @@ namespace _inceptionimpl { ...@@ -9,10 +9,10 @@ namespace _inceptionimpl {
BasicConv2dImpl::BasicConv2dImpl( BasicConv2dImpl::BasicConv2dImpl(
torch::nn::Conv2dOptions options, torch::nn::Conv2dOptions options,
double std_dev) { double std_dev) {
options.with_bias(false); options.bias(false);
conv = torch::nn::Conv2d(options); conv = torch::nn::Conv2d(options);
bn = torch::nn::BatchNorm( bn = torch::nn::BatchNorm(
torch::nn::BatchNormOptions(options.output_channels()).eps(0.001)); torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
register_module("conv", conv); register_module("conv", conv);
register_module("bn", bn); register_module("bn", bn);
......
...@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module { ...@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
apply_residual = input == output && stride == 1; apply_residual = input == output && stride == 1;
layers->push_back( layers->push_back(
torch::nn::Conv2d(Options(input, mid, 1).with_bias(false))); torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum))); torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back( layers->push_back(
...@@ -34,13 +34,13 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module { ...@@ -34,13 +34,13 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
.padding(kernel / 2) .padding(kernel / 2)
.stride(stride) .stride(stride)
.groups(mid) .groups(mid)
.with_bias(false)))); .bias(false))));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(mid).momentum(bn_momentum))); torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
layers->push_back( layers->push_back(
torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_))); torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
layers->push_back( layers->push_back(
torch::nn::Conv2d(Options(mid, output, 1).with_bias(false))); torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(output).momentum(bn_momentum))); torch::nn::BatchNormOptions(output).momentum(bn_momentum)));
...@@ -129,17 +129,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) { ...@@ -129,17 +129,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
auto depths = scale_depths({24, 40, 80, 96, 192, 320}, alpha); auto depths = scale_depths({24, 40, 80, 96, 192, 320}, alpha);
layers->push_back(torch::nn::Conv2d( layers->push_back(torch::nn::Conv2d(
Options(3, 32, 3).padding(1).stride(2).with_bias(false))); Options(3, 32, 3).padding(1).stride(2).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM))); torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_)); layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(torch::nn::Conv2d( layers->push_back(torch::nn::Conv2d(
Options(32, 32, 3).padding(1).stride(1).groups(32).with_bias(false))); Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM))); torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_)); layers->push_back(torch::nn::Functional(modelsimpl::relu_));
layers->push_back(torch::nn::Conv2d( layers->push_back(torch::nn::Conv2d(
Options(32, 16, 1).padding(0).stride(1).with_bias(false))); Options(32, 16, 1).padding(0).stride(1).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM))); torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));
...@@ -151,7 +151,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) { ...@@ -151,7 +151,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
layers->push_back(stack(depths[4], depths[5], 3, 1, 6, 1, BN_MOMENTUM)); layers->push_back(stack(depths[4], depths[5], 3, 1, 6, 1, BN_MOMENTUM));
layers->push_back(torch::nn::Conv2d( layers->push_back(torch::nn::Conv2d(
Options(depths[5], 1280, 1).padding(0).stride(1).with_bias(false))); Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
layers->push_back(torch::nn::BatchNorm( layers->push_back(torch::nn::BatchNorm(
torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM))); torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
layers->push_back(torch::nn::Functional(modelsimpl::relu_)); layers->push_back(torch::nn::Functional(modelsimpl::relu_));
......
...@@ -32,7 +32,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl { ...@@ -32,7 +32,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
.stride(stride) .stride(stride)
.padding(padding) .padding(padding)
.groups(groups) .groups(groups)
.with_bias(false))); .bias(false)));
push_back(torch::nn::BatchNorm(out_planes)); push_back(torch::nn::BatchNorm(out_planes));
push_back(torch::nn::Functional(modelsimpl::relu6_)); push_back(torch::nn::Functional(modelsimpl::relu6_));
} }
...@@ -67,7 +67,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module { ...@@ -67,7 +67,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim)); conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
conv->push_back(torch::nn::Conv2d( conv->push_back(torch::nn::Conv2d(
Options(hidden_dim, output, 1).stride(1).padding(0).with_bias(false))); Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
conv->push_back(torch::nn::BatchNorm(output)); conv->push_back(torch::nn::BatchNorm(output));
register_module("conv", conv); register_module("conv", conv);
...@@ -136,7 +136,7 @@ MobileNetV2Impl::MobileNetV2Impl( ...@@ -136,7 +136,7 @@ MobileNetV2Impl::MobileNetV2Impl(
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) { if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_( torch::nn::init::kaiming_normal_(
M->weight, 0, torch::nn::init::FanMode::FanOut); M->weight, 0, torch::nn::init::FanMode::FanOut);
if (M->options.with_bias()) if (M->options.bias())
torch::nn::init::zeros_(M->bias); torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) { } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::ones_(M->weight); torch::nn::init::ones_(M->weight);
......
...@@ -11,13 +11,13 @@ torch::nn::Conv2d conv3x3( ...@@ -11,13 +11,13 @@ torch::nn::Conv2d conv3x3(
int64_t stride, int64_t stride,
int64_t groups) { int64_t groups) {
torch::nn::Conv2dOptions O(in, out, 3); torch::nn::Conv2dOptions O(in, out, 3);
O.padding(1).stride(stride).groups(groups).with_bias(false); O.padding(1).stride(stride).groups(groups).bias(false);
return torch::nn::Conv2d(O); return torch::nn::Conv2d(O);
} }
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) { torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) {
torch::nn::Conv2dOptions O(in, out, 1); torch::nn::Conv2dOptions O(in, out, 1);
O.stride(stride).with_bias(false); O.stride(stride).bias(false);
return torch::nn::Conv2d(O); return torch::nn::Conv2d(O);
} }
......
...@@ -124,7 +124,7 @@ ResNetImpl<Block>::ResNetImpl( ...@@ -124,7 +124,7 @@ ResNetImpl<Block>::ResNetImpl(
: groups(groups), : groups(groups),
base_width(width_per_group), base_width(width_per_group),
inplanes(64), inplanes(64),
conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).with_bias( conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(
false)), false)),
bn1(64), bn1(64),
layer1(_make_layer(64, layers[0])), layer1(_make_layer(64, layers[0])),
......
...@@ -25,13 +25,13 @@ torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) { ...@@ -25,13 +25,13 @@ torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) {
torch::nn::Conv2d conv11(int64_t input, int64_t output) { torch::nn::Conv2d conv11(int64_t input, int64_t output) {
Options opts(input, output, 1); Options opts(input, output, 1);
opts = opts.stride(1).padding(0).with_bias(false); opts = opts.stride(1).padding(0).bias(false);
return torch::nn::Conv2d(opts); return torch::nn::Conv2d(opts);
} }
torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) { torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) {
Options opts(input, output, 3); Options opts(input, output, 3);
opts = opts.stride(stride).padding(1).with_bias(false).groups(input); opts = opts.stride(stride).padding(1).bias(false).groups(input);
return torch::nn::Conv2d(opts); return torch::nn::Conv2d(opts);
} }
...@@ -107,7 +107,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl( ...@@ -107,7 +107,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
torch::nn::Conv2d(Options(input_channels, output_channels, 3) torch::nn::Conv2d(Options(input_channels, output_channels, 3)
.stride(2) .stride(2)
.padding(1) .padding(1)
.with_bias(false)), .bias(false)),
torch::nn::BatchNorm(output_channels), torch::nn::BatchNorm(output_channels),
torch::nn::Functional(modelsimpl::relu_)); torch::nn::Functional(modelsimpl::relu_));
...@@ -134,7 +134,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl( ...@@ -134,7 +134,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
torch::nn::Conv2d(Options(input_channels, output_channels, 1) torch::nn::Conv2d(Options(input_channels, output_channels, 1)
.stride(1) .stride(1)
.padding(0) .padding(0)
.with_bias(false)), .bias(false)),
torch::nn::BatchNorm(output_channels), torch::nn::BatchNorm(output_channels),
torch::nn::Functional(modelsimpl::relu_)); torch::nn::Functional(modelsimpl::relu_));
......
...@@ -91,7 +91,7 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes) ...@@ -91,7 +91,7 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
else else
torch::nn::init::kaiming_uniform_(M->weight); torch::nn::init::kaiming_uniform_(M->weight);
if (M->options.with_bias()) if (M->options.bias())
torch::nn::init::constant_(M->bias, 0); torch::nn::init::constant_(M->bias, 0);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment