Commit 511244e2 authored by acivgin1's avatar acivgin1
Browse files

better handling of weight loading for spconv1 vs 2

parent 667572fd
import subprocess
from pathlib import Path
from packaging import version as p_version
from .version import __version__
__all__ = [
......@@ -24,7 +22,3 @@ script_version = get_git_commit_number()
if script_version not in __version__:
__version__ = __version__ + '+py%s' % script_version
def v1_is_lower_than_v2(version1: str, version2: str):
return p_version.parse(version1) < p_version.parse(version2)
......@@ -3,7 +3,6 @@ import os
import torch
import torch.nn as nn
from ... import v1_is_lower_than_v2
from ...ops.iou3d_nms import iou3d_nms_utils
from ...spconv_utils import find_all_spconv_keys
from .. import backbones_2d, backbones_3d, dense_heads, roi_heads
......@@ -327,17 +326,16 @@ class Detector3DTemplate(nn.Module):
gt_iou = box_preds.new_zeros(box_preds.shape[0])
return recall_dict
def _load_state_dict(self, model_state_disk, version, *, strict=True):
def _load_state_dict(self, model_state_disk, *, strict=True):
state_dict = self.state_dict() # local cache of state_dict
version = version.split("+")[1]
spconv_keys = find_all_spconv_keys(self)
update_model_state = {}
for key, val in model_state_disk.items():
if version is None or v1_is_lower_than_v2(version, "0.4.0"): # spconv change
if key in spconv_keys:
val = val.transpose(-1, -2).contiguous()
if key in spconv_keys and key in state_dict and state_dict[key].shape != val.shape:
# with different spconv versions, we need to adapt weight shapes for spconv blocks
val = val.transpose(-1, -2).contiguous()
if key in state_dict and state_dict[key].shape == val.shape:
update_model_state[key] = val
......@@ -363,7 +361,7 @@ class Detector3DTemplate(nn.Module):
if version is not None:
logger.info('==> Checkpoint trained from version: %s' % version)
state_dict, update_model_state = self._load_state_dict(model_state_disk, version, strict=False)
state_dict, update_model_state = self._load_state_dict(model_state_disk, strict=False)
for key in state_dict:
if key not in update_model_state:
......@@ -381,8 +379,7 @@ class Detector3DTemplate(nn.Module):
epoch = checkpoint.get('epoch', -1)
it = checkpoint.get('it', 0.0)
version = checkpoint.get("version", None)
self._load_state_dict(checkpoint['model_state'], version, strict=True)
self._load_state_dict(checkpoint['model_state'], strict=True)
if optimizer is not None:
if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment