Unverified Commit 3122ea16 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Respect `strict=False` when loading detection models (#5841)

* Convert weights only if `old_key` is in `state_dict`

* Fix linter
parent 92eb12d6
......@@ -317,6 +317,7 @@ class MaskRCNNHeads(nn.Sequential):
for type in ["weight", "bias"]:
old_key = f"{prefix}mask_fcn{i+1}.{type}"
new_key = f"{prefix}{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
......
......@@ -45,6 +45,7 @@ def _v1_to_v2_weights(state_dict, prefix):
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{2*i}.{type}"
new_key = f"{prefix}conv.{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
......
......@@ -56,6 +56,7 @@ class RPNHead(nn.Module):
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{type}"
new_key = f"{prefix}conv.0.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
......
......@@ -128,6 +128,7 @@ class FeaturePyramidNetwork(nn.Module):
for type in ["weight", "bias"]:
old_key = f"{prefix}{block}.{i}.{type}"
new_key = f"{prefix}{block}.{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment