Unverified commit b6f28ec1, authored by Francis Charette Migneault, committed by GitHub

replace torch 1.5.0 items flagged with deprecation warnings (fix #1906) (#1918)

parent 6aa99ced
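The change is mechanical across all the model files below: the torch 1.5.0 C++ frontend deprecates the dimension-agnostic `torch::nn::BatchNorm` module (and its `BatchNormImpl` class) in favor of `BatchNorm1d/2d/3d`, and since every model here consumes 4-D NCHW tensors, the replacement is `BatchNorm2d` throughout. A minimal before/after sketch of the pattern — illustration only, not code from the diff:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Deprecated since torch 1.5.0 (flagged with a warning):
  //   auto bn = torch::nn::BatchNorm(
  //       torch::nn::BatchNormOptions(64).eps(0.001));

  // Replacement: dimension-specific module for 4-D (N, C, H, W) input.
  // BatchNorm2dOptions aliases BatchNormOptions, so the existing
  // option-chaining calls in these files compile unchanged.
  auto bn = torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(64).eps(0.001).momentum(0.1));

  torch::Tensor y = bn(torch::randn({8, 64, 32, 32}));
  std::cout << y.sizes() << '\n'; // [8, 64, 32, 32]
}
```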
@@ -15,14 +15,14 @@ struct _DenseLayerImpl : torch::nn::SequentialImpl {
       int64_t bn_size,
       double drop_rate)
       : drop_rate(drop_rate) {
-    push_back("norm1", torch::nn::BatchNorm(num_input_features));
+    push_back("norm1", torch::nn::BatchNorm2d(num_input_features));
     push_back("relu1", torch::nn::Functional(modelsimpl::relu_));
     push_back(
         "conv1",
         torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1)
                               .stride(1)
                               .bias(false)));
-    push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate));
+    push_back("norm2", torch::nn::BatchNorm2d(bn_size * growth_rate));
     push_back("relu2", torch::nn::Functional(modelsimpl::relu_));
     push_back(
         "conv2",
@@ -69,7 +69,7 @@ TORCH_MODULE(_DenseBlock);
 struct _TransitionImpl : torch::nn::SequentialImpl {
   _TransitionImpl(int64_t num_input_features, int64_t num_output_features) {
-    push_back("norm", torch::nn::BatchNorm(num_input_features));
+    push_back("norm", torch::nn::BatchNorm2d(num_input_features));
     push_back("relu ", torch::nn::Functional(modelsimpl::relu_));
     push_back(
         "conv",
@@ -102,7 +102,7 @@ DenseNetImpl::DenseNetImpl(
       torch::nn::Conv2d(
           Options(3, num_init_features, 7).stride(2).padding(3).bias(false)));
-  features->push_back("norm0", torch::nn::BatchNorm(num_init_features));
+  features->push_back("norm0", torch::nn::BatchNorm2d(num_init_features));
   features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_));
   features->push_back(
       "pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false));
@@ -125,7 +125,7 @@ DenseNetImpl::DenseNetImpl(
   }
   // Final batch norm
-  features->push_back("norm5", torch::nn::BatchNorm(num_features));
+  features->push_back("norm5", torch::nn::BatchNorm2d(num_features));
   // Linear layer
   classifier = torch::nn::Linear(num_features, num_classes);
@@ -136,7 +136,7 @@ DenseNetImpl::DenseNetImpl(
   for (auto& module : modules(/*include_self=*/false)) {
     if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
      torch::nn::init::kaiming_normal_(M->weight);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
       torch::nn::init::constant_(M->bias, 0);
     } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
...
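One detail worth calling out in the init loop at the end of the hunk above: `modules(/*include_self=*/false)` yields type-erased `std::shared_ptr<Module>`, and each `dynamic_cast` only succeeds when the module actually is the target `Impl` type. The legacy `BatchNormImpl` and the new `BatchNorm2dImpl` are unrelated classes, so once the layers are built as `BatchNorm2d`, a cast against the old type matches nothing and the batch-norm weights would silently keep their defaults — which is why the cast target changes together with the module type. A self-contained sketch of the dispatch (the `Net` module here is hypothetical, for illustration only):

```cpp
#include <torch/torch.h>

// Hypothetical two-layer module, used only to demonstrate the init loop.
struct NetImpl : torch::nn::Module {
  NetImpl() {
    conv = register_module("conv", torch::nn::Conv2d(3, 8, 3));
    bn = register_module("bn", torch::nn::BatchNorm2d(8));
  }
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm2d bn{nullptr};
};
TORCH_MODULE(Net);

int main() {
  Net net;
  for (auto& module : net->modules(/*include_self=*/false)) {
    if (auto* M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(M->weight);
    } else if (
        auto* M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      // A cast to the pre-1.5 BatchNormImpl would never match here.
      torch::nn::init::constant_(M->weight, 1);
      torch::nn::init::constant_(M->bias, 0);
    }
  }
}
```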
@@ -11,7 +11,7 @@ namespace _googlenetimpl {
 BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) {
   options.bias(false);
   conv = torch::nn::Conv2d(options);
-  bn = torch::nn::BatchNorm(
+  bn = torch::nn::BatchNorm2d(
       torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
   register_module("conv", conv);
@@ -155,7 +155,7 @@ void GoogLeNetImpl::_initialize_weights() {
     else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get()))
       torch::nn::init::normal_(M->weight); // Note: used instead of truncated
                                            // normal initialization
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::ones_(M->weight);
       torch::nn::init::zeros_(M->bias);
     }
...
@@ -10,7 +10,7 @@ namespace models {
 namespace _googlenetimpl {
 struct VISION_API BasicConv2dImpl : torch::nn::Module {
   torch::nn::Conv2d conv{nullptr};
-  torch::nn::BatchNorm bn{nullptr};
+  torch::nn::BatchNorm2d bn{nullptr};
   BasicConv2dImpl(torch::nn::Conv2dOptions options);
...
@@ -11,7 +11,7 @@ BasicConv2dImpl::BasicConv2dImpl(
     double std_dev) {
   options.bias(false);
   conv = torch::nn::Conv2d(options);
-  bn = torch::nn::BatchNorm(
+  bn = torch::nn::BatchNorm2d(
       torch::nn::BatchNormOptions(options.out_channels()).eps(0.001));
   register_module("conv", conv);
...
@@ -9,7 +9,7 @@ namespace models {
 namespace _inceptionimpl {
 struct VISION_API BasicConv2dImpl : torch::nn::Module {
   torch::nn::Conv2d conv{nullptr};
-  torch::nn::BatchNorm bn{nullptr};
+  torch::nn::BatchNorm2d bn{nullptr};
   BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);
...
@@ -24,7 +24,7 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
     apply_residual = input == output && stride == 1;
     layers->push_back(torch::nn::Conv2d(Options(input, mid, 1).bias(false)));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
         torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
     layers->push_back(
         torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
@@ -34,12 +34,12 @@ struct MNASNetInvertedResidualImpl : torch::nn::Module {
             .stride(stride)
             .groups(mid)
             .bias(false))));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
         torch::nn::BatchNormOptions(mid).momentum(bn_momentum)));
     layers->push_back(
         torch::nn::Functional(torch::nn::Functional(modelsimpl::relu_)));
     layers->push_back(torch::nn::Conv2d(Options(mid, output, 1).bias(false)));
-    layers->push_back(torch::nn::BatchNorm(
+    layers->push_back(torch::nn::BatchNorm2d(
         torch::nn::BatchNormOptions(output).momentum(bn_momentum)));
     register_module("layers", layers);
@@ -109,9 +109,9 @@ void MNASNetImpl::_initialize_weights() {
       torch::nn::init::kaiming_normal_(
           M->weight,
           0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
@@ -128,17 +128,17 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
   layers->push_back(
       torch::nn::Conv2d(Options(3, 32, 3).padding(1).stride(2).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
       torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
   layers->push_back(torch::nn::Functional(modelsimpl::relu_));
   layers->push_back(torch::nn::Conv2d(
       Options(32, 32, 3).padding(1).stride(1).groups(32).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
       torch::nn::BatchNormOptions(32).momentum(BN_MOMENTUM)));
   layers->push_back(torch::nn::Functional(modelsimpl::relu_));
   layers->push_back(
       torch::nn::Conv2d(Options(32, 16, 1).padding(0).stride(1).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
      torch::nn::BatchNormOptions(16).momentum(BN_MOMENTUM)));
   layers->push_back(stack(16, depths[0], 3, 2, 3, 3, BN_MOMENTUM));
@@ -150,7 +150,7 @@ MNASNetImpl::MNASNetImpl(double alpha, int64_t num_classes, double dropout) {
   layers->push_back(torch::nn::Conv2d(
       Options(depths[5], 1280, 1).padding(0).stride(1).bias(false)));
-  layers->push_back(torch::nn::BatchNorm(
+  layers->push_back(torch::nn::BatchNorm2d(
       torch::nn::BatchNormOptions(1280).momentum(BN_MOMENTUM)));
   layers->push_back(torch::nn::Functional(modelsimpl::relu_));
...
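The second deprecation handled by this commit shows up in the `_initialize_weights()` hunk above: `torch::nn::init::kaiming_normal_` no longer takes the `FanMode` and `Nonlinearity` enum classes, which torch 1.5.0 replaces with variant-style tags (`torch::kFanOut`, `torch::kReLU`, mirroring the Python API's string arguments). A minimal sketch of the updated call:

```cpp
#include <torch/torch.h>

int main() {
  // A conv-shaped weight tensor: (out_channels, in_channels, kH, kW).
  torch::Tensor w = torch::empty({32, 16, 3, 3});

  // Deprecated enum form (pre-1.5):
  //   torch::nn::init::kaiming_normal_(
  //       w, 0, torch::nn::init::FanMode::FanOut,
  //       torch::nn::init::Nonlinearity::ReLU);

  // Variant-tag form used by this commit:
  torch::nn::init::kaiming_normal_(
      w,
      /*a=*/0,
      torch::kFanOut,
      torch::kReLU);
}
```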
@@ -33,7 +33,7 @@ struct ConvBNReLUImpl : torch::nn::SequentialImpl {
             .padding(padding)
             .groups(groups)
             .bias(false)));
-    push_back(torch::nn::BatchNorm(out_planes));
+    push_back(torch::nn::BatchNorm2d(out_planes));
     push_back(torch::nn::Functional(modelsimpl::relu6_));
   }
@@ -68,7 +68,7 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
     conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
     conv->push_back(torch::nn::Conv2d(
         Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
-    conv->push_back(torch::nn::BatchNorm(output));
+    conv->push_back(torch::nn::BatchNorm2d(output));
     register_module("conv", conv);
   }
@@ -135,10 +135,10 @@ MobileNetV2Impl::MobileNetV2Impl(
   for (auto& module : modules(/*include_self=*/false)) {
     if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
       torch::nn::init::kaiming_normal_(
-          M->weight, 0, torch::nn::init::FanMode::FanOut);
+          M->weight, 0, torch::kFanOut);
       if (M->options.bias())
         torch::nn::init::zeros_(M->bias);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::ones_(M->weight);
       torch::nn::init::zeros_(M->bias);
     } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
...
@@ -40,8 +40,8 @@ BasicBlock::BasicBlock(
   conv1 = conv3x3(inplanes, planes, stride);
   conv2 = conv3x3(planes, planes);
-  bn1 = torch::nn::BatchNorm(planes);
-  bn2 = torch::nn::BatchNorm(planes);
+  bn1 = torch::nn::BatchNorm2d(planes);
+  bn2 = torch::nn::BatchNorm2d(planes);
   register_module("conv1", conv1);
   register_module("conv2", conv2);
@@ -68,9 +68,9 @@ Bottleneck::Bottleneck(
   conv2 = conv3x3(width, width, stride, groups);
   conv3 = conv1x1(width, planes * expansion);
-  bn1 = torch::nn::BatchNorm(width);
-  bn2 = torch::nn::BatchNorm(width);
-  bn3 = torch::nn::BatchNorm(planes * expansion);
+  bn1 = torch::nn::BatchNorm2d(width);
+  bn2 = torch::nn::BatchNorm2d(width);
+  bn3 = torch::nn::BatchNorm2d(planes * expansion);
   register_module("conv1", conv1);
   register_module("conv2", conv2);
...
@@ -28,7 +28,7 @@ struct VISION_API BasicBlock : torch::nn::Module {
   torch::nn::Sequential downsample;
   torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
-  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr};
+  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};
   static int expansion;
@@ -51,7 +51,7 @@ struct VISION_API Bottleneck : torch::nn::Module {
   torch::nn::Sequential downsample;
   torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
-  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
+  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
   static int expansion;
@@ -71,7 +71,7 @@ template <typename Block>
 struct ResNetImpl : torch::nn::Module {
   int64_t groups, base_width, inplanes;
   torch::nn::Conv2d conv1;
-  torch::nn::BatchNorm bn1;
+  torch::nn::BatchNorm2d bn1;
   torch::nn::Sequential layer1, layer2, layer3, layer4;
   torch::nn::Linear fc;
@@ -99,7 +99,7 @@ torch::nn::Sequential ResNetImpl<Block>::_make_layer(
   if (stride != 1 || inplanes != planes * Block::expansion) {
     downsample = torch::nn::Sequential(
         _resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride),
-        torch::nn::BatchNorm(planes * Block::expansion));
+        torch::nn::BatchNorm2d(planes * Block::expansion));
   }
   torch::nn::Sequential layers;
@@ -146,9 +146,9 @@ ResNetImpl<Block>::ResNetImpl(
       torch::nn::init::kaiming_normal_(
           M->weight,
           /*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
-    else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
       torch::nn::init::constant_(M->bias, 0);
     }
...
@@ -49,20 +49,20 @@ struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
     if (stride > 1) {
       branch1 = torch::nn::Sequential(
           conv33(inp, inp, stride),
-          torch::nn::BatchNorm(inp),
+          torch::nn::BatchNorm2d(inp),
           conv11(inp, branch_features),
-          torch::nn::BatchNorm(branch_features),
+          torch::nn::BatchNorm2d(branch_features),
           torch::nn::Functional(modelsimpl::relu_));
     }
     branch2 = torch::nn::Sequential(
         conv11(stride > 1 ? inp : branch_features, branch_features),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
         torch::nn::Functional(modelsimpl::relu_),
         conv33(branch_features, branch_features, stride),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
         conv11(branch_features, branch_features),
-        torch::nn::BatchNorm(branch_features),
+        torch::nn::BatchNorm2d(branch_features),
         torch::nn::Functional(modelsimpl::relu_));
     if (!branch1.is_empty())
@@ -108,7 +108,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
           .stride(2)
          .padding(1)
          .bias(false)),
-      torch::nn::BatchNorm(output_channels),
+      torch::nn::BatchNorm2d(output_channels),
       torch::nn::Functional(modelsimpl::relu_));
   input_channels = output_channels;
@@ -135,7 +135,7 @@ ShuffleNetV2Impl::ShuffleNetV2Impl(
           .stride(1)
          .padding(0)
          .bias(false)),
-      torch::nn::BatchNorm(output_channels),
+      torch::nn::BatchNorm2d(output_channels),
       torch::nn::Functional(modelsimpl::relu_));
   fc = torch::nn::Linear(output_channels, num_classes);
...
@@ -19,7 +19,7 @@ torch::nn::Sequential makeLayers(
           torch::nn::Conv2dOptions(channels, V, 3).padding(1)));
       if (batch_norm)
-        seq->push_back(torch::nn::BatchNorm(V));
+        seq->push_back(torch::nn::BatchNorm2d(V));
       seq->push_back(torch::nn::Functional(modelsimpl::relu_));
       channels = V;
@@ -35,10 +35,10 @@ void VGGImpl::_initialize_weights() {
       torch::nn::init::kaiming_normal_(
          M->weight,
          /*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
       torch::nn::init::constant_(M->bias, 0);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
+    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
       torch::nn::init::constant_(M->bias, 0);
     } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
...
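Beyond the batch-norm swap, the `makeLayers` hunk above shows the builder style these models share: push layers into a `torch::nn::Sequential`, gating the norm layer on a flag. A self-contained re-creation of that pattern under the new API (`make_block` and its arguments are illustrative names, not from the diff):

```cpp
#include <torch/torch.h>
#include <iostream>

// Illustrative VGG-style block builder: 3x3 conv, optional BatchNorm2d, ReLU.
torch::nn::Sequential make_block(int64_t in, int64_t out, bool batch_norm) {
  torch::nn::Sequential seq;
  seq->push_back(
      torch::nn::Conv2d(torch::nn::Conv2dOptions(in, out, 3).padding(1)));
  if (batch_norm)
    seq->push_back(torch::nn::BatchNorm2d(out));
  // The models wrap plain functions as modules via torch::nn::Functional;
  // a lambda avoids the overload ambiguity of passing torch::relu directly.
  seq->push_back(torch::nn::Functional(
      [](torch::Tensor x) { return torch::relu(x); }));
  return seq;
}

int main() {
  auto block = make_block(3, 64, /*batch_norm=*/true);
  torch::Tensor y = block->forward(torch::randn({1, 3, 32, 32}));
  std::cout << y.sizes() << '\n'; // [1, 64, 32, 32]
}
```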