Unverified Commit 3122ea16 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Respect `strict=False` when loading detection models (#5841)

* Convert weights only if `old_key` is in `state_dict`

* Fix linter
parent 92eb12d6
...@@ -317,6 +317,7 @@ class MaskRCNNHeads(nn.Sequential): ...@@ -317,6 +317,7 @@ class MaskRCNNHeads(nn.Sequential):
for type in ["weight", "bias"]: for type in ["weight", "bias"]:
old_key = f"{prefix}mask_fcn{i+1}.{type}" old_key = f"{prefix}mask_fcn{i+1}.{type}"
new_key = f"{prefix}{i}.0.{type}" new_key = f"{prefix}{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict( super()._load_from_state_dict(
......
...@@ -45,6 +45,7 @@ def _v1_to_v2_weights(state_dict, prefix): ...@@ -45,6 +45,7 @@ def _v1_to_v2_weights(state_dict, prefix):
for type in ["weight", "bias"]: for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{2*i}.{type}" old_key = f"{prefix}conv.{2*i}.{type}"
new_key = f"{prefix}conv.{i}.0.{type}" new_key = f"{prefix}conv.{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
......
...@@ -56,6 +56,7 @@ class RPNHead(nn.Module): ...@@ -56,6 +56,7 @@ class RPNHead(nn.Module):
for type in ["weight", "bias"]: for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{type}" old_key = f"{prefix}conv.{type}"
new_key = f"{prefix}conv.0.0.{type}" new_key = f"{prefix}conv.0.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict( super()._load_from_state_dict(
......
...@@ -128,6 +128,7 @@ class FeaturePyramidNetwork(nn.Module): ...@@ -128,6 +128,7 @@ class FeaturePyramidNetwork(nn.Module):
for type in ["weight", "bias"]: for type in ["weight", "bias"]:
old_key = f"{prefix}{block}.{i}.{type}" old_key = f"{prefix}{block}.{i}.{type}"
new_key = f"{prefix}{block}.{i}.0.{type}" new_key = f"{prefix}{block}.{i}.0.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict( super()._load_from_state_dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment