Commit 511244e2 authored by acivgin1
Browse files

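Better handling of weight loading for spconv 1.x vs. 2.x: spconv weights are now transposed when their shape differs from the model's, instead of relying on the checkpoint version.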

parent 667572fd
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from packaging import version as p_version
from .version import __version__ from .version import __version__
__all__ = [ __all__ = [
...@@ -24,7 +22,3 @@ script_version = get_git_commit_number() ...@@ -24,7 +22,3 @@ script_version = get_git_commit_number()
if script_version not in __version__: if script_version not in __version__:
__version__ = __version__ + '+py%s' % script_version __version__ = __version__ + '+py%s' % script_version
def v1_is_lower_than_v2(version1: str, version2: str):
    """Return whether ``version1`` is strictly lower than ``version2``.

    Both strings are parsed with ``p_version.parse`` and the parsed
    results are compared, rather than the raw strings.
    """
    parsed_first = p_version.parse(version1)
    parsed_second = p_version.parse(version2)
    return parsed_first < parsed_second
...@@ -3,7 +3,6 @@ import os ...@@ -3,7 +3,6 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from ... import v1_is_lower_than_v2
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...spconv_utils import find_all_spconv_keys from ...spconv_utils import find_all_spconv_keys
from .. import backbones_2d, backbones_3d, dense_heads, roi_heads from .. import backbones_2d, backbones_3d, dense_heads, roi_heads
...@@ -327,17 +326,16 @@ class Detector3DTemplate(nn.Module): ...@@ -327,17 +326,16 @@ class Detector3DTemplate(nn.Module):
gt_iou = box_preds.new_zeros(box_preds.shape[0]) gt_iou = box_preds.new_zeros(box_preds.shape[0])
return recall_dict return recall_dict
def _load_state_dict(self, model_state_disk, version, *, strict=True): def _load_state_dict(self, model_state_disk, *, strict=True):
state_dict = self.state_dict() # local cache of state_dict state_dict = self.state_dict() # local cache of state_dict
version = version.split("+")[1]
spconv_keys = find_all_spconv_keys(self) spconv_keys = find_all_spconv_keys(self)
update_model_state = {} update_model_state = {}
for key, val in model_state_disk.items(): for key, val in model_state_disk.items():
if version is None or v1_is_lower_than_v2(version, "0.4.0"): # spconv change if key in spconv_keys and key in state_dict and state_dict[key].shape != val.shape:
if key in spconv_keys: # with different spconv versions, we need to adapt weight shapes for spconv blocks
val = val.transpose(-1, -2).contiguous() val = val.transpose(-1, -2).contiguous()
if key in state_dict and state_dict[key].shape == val.shape: if key in state_dict and state_dict[key].shape == val.shape:
update_model_state[key] = val update_model_state[key] = val
...@@ -363,7 +361,7 @@ class Detector3DTemplate(nn.Module): ...@@ -363,7 +361,7 @@ class Detector3DTemplate(nn.Module):
if version is not None: if version is not None:
logger.info('==> Checkpoint trained from version: %s' % version) logger.info('==> Checkpoint trained from version: %s' % version)
state_dict, update_model_state = self._load_state_dict(model_state_disk, version, strict=False) state_dict, update_model_state = self._load_state_dict(model_state_disk, strict=False)
for key in state_dict: for key in state_dict:
if key not in update_model_state: if key not in update_model_state:
...@@ -381,8 +379,7 @@ class Detector3DTemplate(nn.Module): ...@@ -381,8 +379,7 @@ class Detector3DTemplate(nn.Module):
epoch = checkpoint.get('epoch', -1) epoch = checkpoint.get('epoch', -1)
it = checkpoint.get('it', 0.0) it = checkpoint.get('it', 0.0)
version = checkpoint.get("version", None) self._load_state_dict(checkpoint['model_state'], strict=True)
self._load_state_dict(checkpoint['model_state'], version, strict=True)
if optimizer is not None: if optimizer is not None:
if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None: if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment