Commit 8fc1a5d5 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: weights transformation between spconv 1.x and spconv 2.x, should...

bugfixed: weights transformation between spconv 1.x and spconv 2.x, should consider both native/implicit spconv 2.x
parent cddcf9ba
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...spconv_utils import find_all_spconv_keys from ...utils.spconv_utils import find_all_spconv_keys
from .. import backbones_2d, backbones_3d, dense_heads, roi_heads from .. import backbones_2d, backbones_3d, dense_heads, roi_heads
from ..backbones_2d import map_to_bev from ..backbones_2d import map_to_bev
from ..backbones_3d import pfe, vfe from ..backbones_3d import pfe, vfe
...@@ -335,7 +335,16 @@ class Detector3DTemplate(nn.Module): ...@@ -335,7 +335,16 @@ class Detector3DTemplate(nn.Module):
for key, val in model_state_disk.items(): for key, val in model_state_disk.items():
if key in spconv_keys and key in state_dict and state_dict[key].shape != val.shape: if key in spconv_keys and key in state_dict and state_dict[key].shape != val.shape:
# with different spconv versions, we need to adapt weight shapes for spconv blocks # with different spconv versions, we need to adapt weight shapes for spconv blocks
val = val.transpose(-1, -2).contiguous() # adapt spconv weights from version 1.x to version 2.x if you used weights from spconv 1.x
val_native = val.transpose(-1, -2) # (k1, k2, k3, c_in, c_out) to (k1, k2, k3, c_out, c_in)
if val_native.shape == state_dict[key].shape:
val = val_native.contiguous()
else:
assert val.shape.__len__() == 5, 'currently only spconv 3D is supported'
val_implicit = val.permute(4, 0, 1, 2, 3) # (k1, k2, k3, c_in, c_out) to (c_out, k1, k2, k3, c_in)
if val_implicit.shape == state_dict[key].shape:
val = val_implicit.contiguous()
if key in state_dict and state_dict[key].shape == val.shape: if key in state_dict and state_dict[key].shape == val.shape:
update_model_state[key] = val update_model_state[key] = val
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment