Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
511244e2
Commit
511244e2
authored
Nov 16, 2021
by
acivgin1
Browse files
better handling of weight loading for spconv1 vs 2
parent
667572fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
15 deletions
+6
-15
pcdet/__init__.py
pcdet/__init__.py
+0
-6
pcdet/models/detectors/detector3d_template.py
pcdet/models/detectors/detector3d_template.py
+6
-9
No files found.
pcdet/__init__.py
View file @
511244e2
import
subprocess
import
subprocess
from
pathlib
import
Path
from
pathlib
import
Path
from
packaging
import
version
as
p_version
from
.version
import
__version__
from
.version
import
__version__
__all__
=
[
__all__
=
[
...
@@ -24,7 +22,3 @@ script_version = get_git_commit_number()
...
@@ -24,7 +22,3 @@ script_version = get_git_commit_number()
if
script_version
not
in
__version__
:
if
script_version
not
in
__version__
:
__version__
=
__version__
+
'+py%s'
%
script_version
__version__
=
__version__
+
'+py%s'
%
script_version
def
v1_is_lower_than_v2
(
version1
:
str
,
version2
:
str
):
return
p_version
.
parse
(
version1
)
<
p_version
.
parse
(
version2
)
pcdet/models/detectors/detector3d_template.py
View file @
511244e2
...
@@ -3,7 +3,6 @@ import os
...
@@ -3,7 +3,6 @@ import os
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
...
import
v1_is_lower_than_v2
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...spconv_utils
import
find_all_spconv_keys
from
...spconv_utils
import
find_all_spconv_keys
from
..
import
backbones_2d
,
backbones_3d
,
dense_heads
,
roi_heads
from
..
import
backbones_2d
,
backbones_3d
,
dense_heads
,
roi_heads
...
@@ -327,17 +326,16 @@ class Detector3DTemplate(nn.Module):
...
@@ -327,17 +326,16 @@ class Detector3DTemplate(nn.Module):
gt_iou
=
box_preds
.
new_zeros
(
box_preds
.
shape
[
0
])
gt_iou
=
box_preds
.
new_zeros
(
box_preds
.
shape
[
0
])
return
recall_dict
return
recall_dict
def
_load_state_dict
(
self
,
model_state_disk
,
version
,
*
,
strict
=
True
):
def
_load_state_dict
(
self
,
model_state_disk
,
*
,
strict
=
True
):
state_dict
=
self
.
state_dict
()
# local cache of state_dict
state_dict
=
self
.
state_dict
()
# local cache of state_dict
version
=
version
.
split
(
"+"
)[
1
]
spconv_keys
=
find_all_spconv_keys
(
self
)
spconv_keys
=
find_all_spconv_keys
(
self
)
update_model_state
=
{}
update_model_state
=
{}
for
key
,
val
in
model_state_disk
.
items
():
for
key
,
val
in
model_state_disk
.
items
():
if
version
is
None
or
v1_is_lower_than_v2
(
version
,
"0.4.0"
):
# spconv change
if
key
in
spconv_keys
and
key
in
state_dict
and
state_dict
[
key
].
shape
!=
val
.
shape
:
if
key
in
spconv_keys
:
# with different spconv versions, we need to adapt weight shapes for spconv blocks
val
=
val
.
transpose
(
-
1
,
-
2
).
contiguous
()
val
=
val
.
transpose
(
-
1
,
-
2
).
contiguous
()
if
key
in
state_dict
and
state_dict
[
key
].
shape
==
val
.
shape
:
if
key
in
state_dict
and
state_dict
[
key
].
shape
==
val
.
shape
:
update_model_state
[
key
]
=
val
update_model_state
[
key
]
=
val
...
@@ -363,7 +361,7 @@ class Detector3DTemplate(nn.Module):
...
@@ -363,7 +361,7 @@ class Detector3DTemplate(nn.Module):
if
version
is
not
None
:
if
version
is
not
None
:
logger
.
info
(
'==> Checkpoint trained from version: %s'
%
version
)
logger
.
info
(
'==> Checkpoint trained from version: %s'
%
version
)
state_dict
,
update_model_state
=
self
.
_load_state_dict
(
model_state_disk
,
version
,
strict
=
False
)
state_dict
,
update_model_state
=
self
.
_load_state_dict
(
model_state_disk
,
strict
=
False
)
for
key
in
state_dict
:
for
key
in
state_dict
:
if
key
not
in
update_model_state
:
if
key
not
in
update_model_state
:
...
@@ -381,8 +379,7 @@ class Detector3DTemplate(nn.Module):
...
@@ -381,8 +379,7 @@ class Detector3DTemplate(nn.Module):
epoch
=
checkpoint
.
get
(
'epoch'
,
-
1
)
epoch
=
checkpoint
.
get
(
'epoch'
,
-
1
)
it
=
checkpoint
.
get
(
'it'
,
0.0
)
it
=
checkpoint
.
get
(
'it'
,
0.0
)
version
=
checkpoint
.
get
(
"version"
,
None
)
self
.
_load_state_dict
(
checkpoint
[
'model_state'
],
strict
=
True
)
self
.
_load_state_dict
(
checkpoint
[
'model_state'
],
version
,
strict
=
True
)
if
optimizer
is
not
None
:
if
optimizer
is
not
None
:
if
'optimizer_state'
in
checkpoint
and
checkpoint
[
'optimizer_state'
]
is
not
None
:
if
'optimizer_state'
in
checkpoint
and
checkpoint
[
'optimizer_state'
]
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment