"...text-generation-inference.git" did not exist on "37b64a5c10a87f25b5de1c3c55f3ea965104f290"
Unverified Commit ef0ffb80 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Revert "Change all torch::nn::init::Nonlinearity::{name} and torch::nn::init::FanMode::{name} to torch::k{name} (#1394)" (#1428)

Revert "Change all torch::nn::init::Nonlinearity::{name} and torch::nn::init::FanMode::{name} to torch::k{name} (#1394)" (#1428)

This reverts commit 8c3cea7f.
parent 2060576e
......@@ -109,7 +109,10 @@ void MNASNetImpl::_initialize_weights() {
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::kFanOut, torch::kReLU);
M->weight,
0,
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
......
......@@ -134,7 +134,8 @@ MobileNetV2Impl::MobileNetV2Impl(
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::nn::init::FanMode::FanOut);
if (M->options.with_bias())
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
......
......@@ -146,8 +146,8 @@ ResNetImpl<Block>::ResNetImpl(
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
torch::kFanOut,
torch::kReLU);
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
......
......@@ -35,8 +35,8 @@ void VGGImpl::_initialize_weights() {
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
torch::kFanOut,
torch::kReLU);
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.