"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fbcd3ba6b27a2b019c7209aeb3073c41b72bff43"
Commit e22c105c authored by Will Feng

Change all torch::nn::init::Nonlinearity::{name} and torch::nn::init::FanMode::{name} usage to torch::k{name}
parent 17e355f7
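
The change is mechanical: the enum values previously scoped under torch::nn::init::FanMode and torch::nn::init::Nonlinearity become the top-level torch::k{name} variant values, with the call signature of kaiming_normal_ otherwise unchanged. A minimal before/after sketch of a call site (the tensor name w is illustrative, not from the diff):

#include <torch/torch.h>

int main() {
  auto w = torch::empty({64, 3, 7, 7});

  // Before this commit: nested enum values.
  // torch::nn::init::kaiming_normal_(
  //     w, /*a=*/0,
  //     torch::nn::init::FanMode::FanOut,
  //     torch::nn::init::Nonlinearity::ReLU);

  // After this commit: the equivalent top-level variant values.
  torch::nn::init::kaiming_normal_(
      w, /*a=*/0, torch::kFanOut, torch::kReLU);
}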
@@ -111,8 +111,8 @@ void MNASNetImpl::_initialize_weights() {
       torch::nn::init::kaiming_normal_(
           M->weight,
           0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
     else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
       torch::nn::init::ones_(M->weight);
       torch::nn::init::zeros_(M->bias);
@@ -135,7 +135,7 @@ MobileNetV2Impl::MobileNetV2Impl(
   for (auto& module : modules(/*include_self=*/false)) {
     if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
       torch::nn::init::kaiming_normal_(
-          M->weight, 0, torch::nn::init::FanMode::FanOut);
+          M->weight, 0, torch::kFanOut);
       if (M->options.with_bias())
         torch::nn::init::zeros_(M->bias);
     } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
@@ -146,8 +146,8 @@ ResNetImpl<Block>::ResNetImpl(
       torch::nn::init::kaiming_normal_(
           M->weight,
           /*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
     else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
       torch::nn::init::constant_(M->bias, 0);
@@ -35,8 +35,8 @@ void VGGImpl::_initialize_weights() {
       torch::nn::init::kaiming_normal_(
           M->weight,
           /*a=*/0,
-          torch::nn::init::FanMode::FanOut,
-          torch::nn::init::Nonlinearity::ReLU);
+          torch::kFanOut,
+          torch::kReLU);
       torch::nn::init::constant_(M->bias, 0);
     } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
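
Taken together, the hunks above all touch the same weight-initialization loop. A self-contained sketch of that pattern, assuming the libtorch API of this commit's era (torch::nn::BatchNormImpl and Conv2dOptions::with_bias() appear as in the diff; both were renamed in later releases):

#include <torch/torch.h>

// Illustrative standalone version of the per-module init loop used by the
// vision models above, with the new torch::k{name} values from this commit.
void initialize_weights(torch::nn::Module& model) {
  for (auto& module : model.modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      // He initialization scaled by the kernel's fan-out, matched to a
      // downstream ReLU.
      torch::nn::init::kaiming_normal_(
          M->weight, /*a=*/0, torch::kFanOut, torch::kReLU);
      if (M->options.with_bias())
        torch::nn::init::zeros_(M->bias);
    } else if (
        auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      // Start batch norm as an identity transform: unit scale, zero shift.
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    }
  }
}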