Unverified Commit 3122ea16 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Respect `strict=False` when loading detection models (#5841)

* Convert weights only if `old_key` is in `state_dict`

* Fix linter
parent 92eb12d6
......@@ -317,7 +317,8 @@ class MaskRCNNHeads(nn.Sequential):
for type in ["weight", "bias"]:
old_key = f"{prefix}mask_fcn{i+1}.{type}"
new_key = f"{prefix}{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
......
......@@ -45,7 +45,8 @@ def _v1_to_v2_weights(state_dict, prefix):
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{2*i}.{type}"
new_key = f"{prefix}conv.{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
def _default_anchorgen():
......
......@@ -56,7 +56,8 @@ class RPNHead(nn.Module):
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{type}"
new_key = f"{prefix}conv.0.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
......
......@@ -128,7 +128,8 @@ class FeaturePyramidNetwork(nn.Module):
for type in ["weight", "bias"]:
old_key = f"{prefix}{block}.{i}.{type}"
new_key = f"{prefix}{block}.{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment