Commit 61763fa9 authored by Sergey Zagoruyko's avatar Sergey Zagoruyko Committed by Francisco Massa
Browse files

fix for loading models with num_batches_tracked in frozen bn (#1728)

parent be6dd472
......@@ -130,6 +130,16 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment