Unverified commit 601ce5fc authored by Francisco Massa, committed by GitHub

Fix C++ lint (#1971)

parent 12be107b
@@ -36,7 +36,13 @@ at::Tensor ROIAlign_forward(
 #endif
   }
   return ROIAlign_forward_cpu(
-      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
+      input,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      aligned);
 }

 at::Tensor ROIAlign_backward(
@@ -137,8 +143,13 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
         input_shape[3],
         ctx->saved_data["sampling_ratio"].toInt(),
         ctx->saved_data["aligned"].toBool());
-    return {
-        grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
+    return {grad_in,
+            Variable(),
+            Variable(),
+            Variable(),
+            Variable(),
+            Variable(),
+            Variable()};
   }
 };
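An illustrative aside, not part of the commit: the reflowed return statement above follows the torch::autograd::Function contract, in which backward() returns exactly one entry per forward() argument and non-differentiable arguments get an empty Variable(). A minimal sketch of that contract, using a hypothetical ScaleFunction in place of ROIAlignFunction:

// Standalone sketch of the custom-autograd-function return contract.
// ScaleFunction is hypothetical; only the structure mirrors ROIAlignFunction.
#include <torch/autograd.h>
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

struct ScaleFunction : public torch::autograd::Function<ScaleFunction> {
  // forward takes two inputs: a tensor and a non-differentiable scalar.
  static variable_list forward(
      AutogradContext* ctx,
      Variable input,
      double scale) {
    ctx->saved_data["scale"] = scale;
    return {input * scale};
  }

  // backward returns one entry per forward argument (excluding ctx);
  // the slot for the double `scale` stays an empty Variable().
  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    auto scale = ctx->saved_data["scale"].toDouble();
    auto grad_in = grad_output[0] * scale;
    return {grad_in, Variable()};
  }
};

// Usage: auto y = ScaleFunction::apply(x, 2.0)[0];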
@@ -107,10 +107,7 @@ void MNASNetImpl::_initialize_weights() {
   for (auto& module : modules(/*include_self=*/false)) {
     if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
       torch::nn::init::kaiming_normal_(
-          M->weight,
-          0,
-          torch::kFanOut,
-          torch::kReLU);
+          M->weight, 0, torch::kFanOut, torch::kReLU);
     else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::ones_(M->weight);
       torch::nn::init::zeros_(M->bias);
@@ -134,11 +134,11 @@ MobileNetV2Impl::MobileNetV2Impl(
   for (auto& module : modules(/*include_self=*/false)) {
     if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
-      torch::nn::init::kaiming_normal_(
-          M->weight, 0, torch::kFanOut);
+      torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
       if (M->options.bias())
         torch::nn::init::zeros_(M->bias);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
+    } else if (
+        auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::ones_(M->weight);
       torch::nn::init::zeros_(M->bias);
     } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
@@ -38,7 +38,8 @@ void VGGImpl::_initialize_weights() {
           torch::kFanOut,
           torch::kReLU);
       torch::nn::init::constant_(M->bias, 0);
-    } else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
+    } else if (
+        auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
       torch::nn::init::constant_(M->weight, 1);
       torch::nn::init::constant_(M->bias, 0);
     } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
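An illustrative aside, not from this commit: the three weight-initialization hunks above (MNASNet, MobileNetV2, VGG) all share the same dispatch pattern, iterating modules(/*include_self=*/false) and dynamic_cast-ing each module pointer to pick the matching initializer. A minimal self-contained sketch of that pattern, with a hypothetical TinyNet module:

// Hypothetical module illustrating the dynamic_cast-based init dispatch
// used by the models touched in this commit.
#include <torch/torch.h>

struct TinyNetImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm2d bn{nullptr};

  TinyNetImpl() {
    conv = register_module(
        "conv", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 8, 3).bias(false)));
    bn = register_module("bn", torch::nn::BatchNorm2d(8));

    // Dispatch on the concrete module type to choose the initializer.
    for (auto& module : modules(/*include_self=*/false)) {
      if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
        torch::nn::init::kaiming_normal_(
            M->weight, 0, torch::kFanOut, torch::kReLU);
      } else if (
          auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
        torch::nn::init::ones_(M->weight);
        torch::nn::init::zeros_(M->bias);
      }
    }
  }
};
TORCH_MODULE(TinyNet);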