Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
0e157c31
Unverified
Commit
0e157c31
authored
Aug 23, 2022
by
VVsssssk
Committed by
GitHub
Aug 23, 2022
Browse files
[Fix]: fix mmcv.model to mmengine.model (#1750)
parent
009d5d6e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
42 deletions
+25
-42
mmdet3d/models/layers/norm.py
mmdet3d/models/layers/norm.py
+3
-3
mmdet3d/models/layers/spconv/__init__.py
mmdet3d/models/layers/spconv/__init__.py
+1
-1
mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
...3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
+17
-34
mmdet3d/models/layers/transformer.py
mmdet3d/models/layers/transformer.py
+4
-4
No files found.
mmdet3d/models/layers/norm.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
mmcv.cnn
import
NORM_LAYERS
from
mmcv.runner
import
force_fp32
from
mmengine.registry
import
MODELS
from
torch
import
distributed
as
dist
from
torch
import
nn
as
nn
from
torch.autograd.function
import
Function
...
...
@@ -25,7 +25,7 @@ class AllReduce(Function):
return
grad_output
@
NORM_LAYER
S
.
register_module
(
'naiveSyncBN1d'
)
@
MODEL
S
.
register_module
(
'naiveSyncBN1d'
)
class
NaiveSyncBatchNorm1d
(
nn
.
BatchNorm1d
):
"""Synchronized Batch Normalization for 3D Tensors.
...
...
@@ -98,7 +98,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
return
output
@
NORM_LAYER
S
.
register_module
(
'naiveSyncBN2d'
)
@
MODEL
S
.
register_module
(
'naiveSyncBN2d'
)
class
NaiveSyncBatchNorm2d
(
nn
.
BatchNorm2d
):
"""Synchronized Batch Normalization for 4D Tensors.
...
...
mmdet3d/models/layers/spconv/__init__.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
from
.
compat
_spconv
2
import
register_spconv2
from
.
overwrite
_spconv
import
register_spconv2
try
:
import
spconv
...
...
mmdet3d/models/layers/spconv/
compat
_spconv2.py
→
mmdet3d/models/layers/spconv/
overwrite_spconv/write
_spconv2.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
import
itertools
from
mm
cv.cnn.bricks
.registry
import
CONV_LAYER
S
from
mm
engine
.registry
import
MODEL
S
from
torch.nn.parameter
import
Parameter
...
...
@@ -17,47 +17,28 @@ def register_spconv2():
except
ImportError
:
return
False
else
:
CONV_LAYER
S
.
_register_module
(
SparseConv2d
,
'SparseConv2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
SparseConv3d
,
'SparseConv3d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
SparseConv4d
,
'SparseConv4d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv2d
,
'SparseConv2d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv3d
,
'SparseConv3d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv4d
,
'SparseConv4d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseConvTranspose2d
,
'SparseConvTranspose2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseConvTranspose3d
,
'SparseConvTranspose3d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseInverseConv2d
,
'SparseInverseConv2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseInverseConv3d
,
'SparseInverseConv3d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv2d
,
'SubMConv2d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv3d
,
'SubMConv3d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv4d
,
'SubMConv4d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv2d
,
'SubMConv2d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv3d
,
'SubMConv3d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv4d
,
'SubMConv4d'
,
force
=
True
)
SparseModule
.
_version
=
2
SparseModule
.
_load_from_state_dict
=
_load_from_state_dict
SparseModule
.
_save_to_state_dict
=
_save_to_state_dict
return
True
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
"""Rewrite this func to compat the convolutional kernel weights between
spconv 1.x in MMCV and 2.x in spconv2.x.
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
param
=
param
if
keep_vars
else
param
.
detach
()
if
name
==
'weight'
:
dims
=
list
(
range
(
1
,
len
(
param
.
shape
)))
+
[
0
]
param
=
param
.
permute
(
*
dims
)
destination
[
prefix
+
name
]
=
param
for
name
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
and
name
not
in
self
.
_non_persistent_buffers_set
:
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
.
detach
()
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
"""Rewrite this func to compat the convolutional kernel weights between
...
...
@@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
version
=
local_metadata
.
get
(
'version'
,
None
)
for
hook
in
self
.
_load_state_dict_pre_hooks
.
values
():
hook
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
)
...
...
@@ -83,6 +65,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
# 0.3.* to version 0.4+
if
len
(
param
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
input_param
=
input_param
[
0
]
if
version
!=
2
:
dims
=
[
len
(
input_param
.
shape
)
-
1
]
+
list
(
range
(
len
(
input_param
.
shape
)
-
1
))
input_param
=
input_param
.
permute
(
*
dims
)
...
...
mmdet3d/models/layers/transformer.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.cnn.bricks.
registry
import
ATTENTION
from
mm
cv.cnn.bricks.transformer
import
POSITIONAL_ENCODING
,
MultiheadAttention
from
mmcv.cnn.bricks.
transformer
import
MultiheadAttention
from
mm
engine.registry
import
MODELS
from
torch
import
nn
as
nn
@
ATTENTION
.
register_module
()
@
MODELS
.
register_module
()
class
GroupFree3DMHA
(
MultiheadAttention
):
"""A warpper for torch.nn.MultiheadAttention for GroupFree3D.
...
...
@@ -108,7 +108,7 @@ class GroupFree3DMHA(MultiheadAttention):
**
kwargs
)
@
POSITIONAL_ENCODING
.
register_module
()
@
MODELS
.
register_module
()
class
ConvBNPositionalEncoding
(
nn
.
Module
):
"""Absolute position embedding with Conv learning.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment