Commit 3817a135 authored by acivgin1's avatar acivgin1
Browse files

bump up version, support for older models

- as per the spconv documentation, the order of channels in weights is now changed
- we need to get all the affected weights, using the util function in spconv_utils
- we need to check the version with the version checker in __init__.py
- we need to transpose last two dimensions of existing weights, from older models
- to avoid duplication a single _load_state_dict method is added
parent 43b3c006
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from packaging import version
from .version import __version__ from .version import __version__
__all__ = [ __all__ = [
...@@ -22,3 +24,7 @@ script_version = get_git_commit_number() ...@@ -22,3 +24,7 @@ script_version = get_git_commit_number()
if script_version not in __version__: if script_version not in __version__:
__version__ = __version__ + '+py%s' % script_version __version__ = __version__ + '+py%s' % script_version
def v1_is_lower_than_v2(version1: str, version2: str):
return version.parse(version1) < version.parse(version2)
...@@ -3,7 +3,9 @@ import os ...@@ -3,7 +3,9 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from ... import v1_is_lower_than_v2
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...spconv_utils import find_all_spconv_keys
from .. import backbones_2d, backbones_3d, dense_heads, roi_heads from .. import backbones_2d, backbones_3d, dense_heads, roi_heads
from ..backbones_2d import map_to_bev from ..backbones_2d import map_to_bev
from ..backbones_3d import pfe, vfe from ..backbones_3d import pfe, vfe
...@@ -325,6 +327,28 @@ class Detector3DTemplate(nn.Module): ...@@ -325,6 +327,28 @@ class Detector3DTemplate(nn.Module):
gt_iou = box_preds.new_zeros(box_preds.shape[0]) gt_iou = box_preds.new_zeros(box_preds.shape[0])
return recall_dict return recall_dict
def _load_state_dict(self, model_state_disk, version, *, strict=True):
state_dict = self.state_dict() # local cache of state_dict
spconv_keys = find_all_spconv_keys(self)
update_model_state = {}
for key, val in model_state_disk.items():
if version is None or v1_is_lower_than_v2(version, "0.4.0"): # spconv change
if key in spconv_keys:
val = val.transpose(-1, -2).contiguous()
if key in state_dict and state_dict[key].shape == val.shape:
update_model_state[key] = val
# logger.info('Update weight %s: %s' % (key, str(val.shape)))
if strict:
self.load_state_dict(update_model_state)
else:
state_dict.update(update_model_state)
self.load_state_dict(state_dict)
return state_dict, update_model_state
def load_params_from_file(self, filename, logger, to_cpu=False): def load_params_from_file(self, filename, logger, to_cpu=False):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise FileNotFoundError raise FileNotFoundError
...@@ -334,24 +358,17 @@ class Detector3DTemplate(nn.Module): ...@@ -334,24 +358,17 @@ class Detector3DTemplate(nn.Module):
checkpoint = torch.load(filename, map_location=loc_type) checkpoint = torch.load(filename, map_location=loc_type)
model_state_disk = checkpoint['model_state'] model_state_disk = checkpoint['model_state']
if 'version' in checkpoint: version = checkpoint.get("version", None)
logger.info('==> Checkpoint trained from version: %s' % checkpoint['version']) if version is not None:
logger.info('==> Checkpoint trained from version: %s' % version)
update_model_state = {}
for key, val in model_state_disk.items():
if key in self.state_dict() and self.state_dict()[key].shape == model_state_disk[key].shape:
update_model_state[key] = val
# logger.info('Update weight %s: %s' % (key, str(val.shape)))
state_dict = self.state_dict() state_dict, update_model_state = self._load_state_dict(model_state_disk, version, strict=False)
state_dict.update(update_model_state)
self.load_state_dict(state_dict)
for key in state_dict: for key in state_dict:
if key not in update_model_state: if key not in update_model_state:
logger.info('Not updated weight %s: %s' % (key, str(state_dict[key].shape))) logger.info('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))
logger.info('==> Done (loaded %d/%d)' % (len(update_model_state), len(self.state_dict()))) logger.info('==> Done (loaded %d/%d)' % (len(update_model_state), len(state_dict)))
def load_params_with_optimizer(self, filename, to_cpu=False, optimizer=None, logger=None): def load_params_with_optimizer(self, filename, to_cpu=False, optimizer=None, logger=None):
if not os.path.isfile(filename): if not os.path.isfile(filename):
...@@ -363,7 +380,8 @@ class Detector3DTemplate(nn.Module): ...@@ -363,7 +380,8 @@ class Detector3DTemplate(nn.Module):
epoch = checkpoint.get('epoch', -1) epoch = checkpoint.get('epoch', -1)
it = checkpoint.get('it', 0.0) it = checkpoint.get('it', 0.0)
self.load_state_dict(checkpoint['model_state']) version = checkpoint.get("version", None)
self._load_state_dict(checkpoint['model_state'], version, strict=True)
if optimizer is not None: if optimizer is not None:
if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None: if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None:
......
from typing import Set
import spconv.pytorch as spconv
import torch.nn as nn
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
"""
Finds all spconv keys that need to have weight's transposed
"""
found_keys: Set[str] = set()
for name, child in model.named_children():
new_prefix = f"{prefix}.{name}" if prefix != "" else name
if isinstance(child, (spconv.SubMConv3d, spconv.SparseConv3d, spconv.SparseInverseConv3d)):
new_prefix = f"{new_prefix}.weight"
found_keys.add(new_prefix)
found_keys.update(find_all_spconv_keys(child, prefix=new_prefix))
return found_keys
...@@ -40,7 +40,7 @@ class PostInstallation(install): ...@@ -40,7 +40,7 @@ class PostInstallation(install):
if __name__ == '__main__': if __name__ == '__main__':
version = '0.3.0+%s' % get_git_commit_number() version = '0.4.0+%s' % get_git_commit_number()
write_version_to_file(version, 'pcdet/version.py') write_version_to_file(version, 'pcdet/version.py')
setup( setup(
...@@ -50,7 +50,7 @@ if __name__ == '__main__': ...@@ -50,7 +50,7 @@ if __name__ == '__main__':
install_requires=[ install_requires=[
'numpy', 'numpy',
'torch>=1.1', 'torch>=1.1',
'spconv', # 'spconv', # spconv has different names depending on the cuda version
'numba', 'numba',
'tensorboardX', 'tensorboardX',
'easydict', 'easydict',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment