Unverified Commit 601ce5fc authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix C++ linnt (#1971)

parent 12be107b
......@@ -36,7 +36,13 @@ at::Tensor ROIAlign_forward(
#endif
}
return ROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
}
at::Tensor ROIAlign_backward(
......@@ -137,8 +143,13 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
return {grad_in,
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable()};
}
};
......
......@@ -107,10 +107,7 @@ void MNASNetImpl::_initialize_weights() {
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
torch::nn::init::kaiming_normal_(
M->weight,
0,
torch::kFanOut,
torch::kReLU);
M->weight, 0, torch::kFanOut, torch::kReLU);
else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
......
......@@ -134,11 +134,11 @@ MobileNetV2Impl::MobileNetV2Impl(
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::kFanOut);
torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
if (M->options.bias())
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
} else if (
auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
......
......@@ -38,7 +38,8 @@ void VGGImpl::_initialize_weights() {
torch::kFanOut,
torch::kReLU);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
} else if (
auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment