Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
0e157c31
Unverified
Commit
0e157c31
authored
Aug 23, 2022
by
VVsssssk
Committed by
GitHub
Aug 23, 2022
Browse files
[Fix]: fix mmcv.model to mmengine.model (#1750)
parent
009d5d6e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
42 deletions
+25
-42
mmdet3d/models/layers/norm.py
mmdet3d/models/layers/norm.py
+3
-3
mmdet3d/models/layers/spconv/__init__.py
mmdet3d/models/layers/spconv/__init__.py
+1
-1
mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
...3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
+17
-34
mmdet3d/models/layers/transformer.py
mmdet3d/models/layers/transformer.py
+4
-4
No files found.
mmdet3d/models/layers/norm.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch
from
mmcv.cnn
import
NORM_LAYERS
from
mmcv.runner
import
force_fp32
from
mmcv.runner
import
force_fp32
from
mmengine.registry
import
MODELS
from
torch
import
distributed
as
dist
from
torch
import
distributed
as
dist
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
torch.autograd.function
import
Function
from
torch.autograd.function
import
Function
...
@@ -25,7 +25,7 @@ class AllReduce(Function):
...
@@ -25,7 +25,7 @@ class AllReduce(Function):
return
grad_output
return
grad_output
@
NORM_LAYER
S
.
register_module
(
'naiveSyncBN1d'
)
@
MODEL
S
.
register_module
(
'naiveSyncBN1d'
)
class
NaiveSyncBatchNorm1d
(
nn
.
BatchNorm1d
):
class
NaiveSyncBatchNorm1d
(
nn
.
BatchNorm1d
):
"""Synchronized Batch Normalization for 3D Tensors.
"""Synchronized Batch Normalization for 3D Tensors.
...
@@ -98,7 +98,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
...
@@ -98,7 +98,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
return
output
return
output
@
NORM_LAYER
S
.
register_module
(
'naiveSyncBN2d'
)
@
MODEL
S
.
register_module
(
'naiveSyncBN2d'
)
class
NaiveSyncBatchNorm2d
(
nn
.
BatchNorm2d
):
class
NaiveSyncBatchNorm2d
(
nn
.
BatchNorm2d
):
"""Synchronized Batch Normalization for 4D Tensors.
"""Synchronized Batch Normalization for 4D Tensors.
...
...
mmdet3d/models/layers/spconv/__init__.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.
compat
_spconv
2
import
register_spconv2
from
.
overwrite
_spconv
import
register_spconv2
try
:
try
:
import
spconv
import
spconv
...
...
mmdet3d/models/layers/spconv/
compat
_spconv2.py
→
mmdet3d/models/layers/spconv/
overwrite_spconv/write
_spconv2.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
itertools
import
itertools
from
mm
cv.cnn.bricks
.registry
import
CONV_LAYER
S
from
mm
engine
.registry
import
MODEL
S
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
...
@@ -17,47 +17,28 @@ def register_spconv2():
...
@@ -17,47 +17,28 @@ def register_spconv2():
except
ImportError
:
except
ImportError
:
return
False
return
False
else
:
else
:
CONV_LAYER
S
.
_register_module
(
SparseConv2d
,
'SparseConv2d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv2d
,
'SparseConv2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
SparseConv3d
,
'SparseConv3d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv3d
,
'SparseConv3d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
SparseConv4d
,
'SparseConv4d'
,
force
=
True
)
MODEL
S
.
_register_module
(
SparseConv4d
,
'SparseConv4d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseConvTranspose2d
,
'SparseConvTranspose2d'
,
force
=
True
)
SparseConvTranspose2d
,
'SparseConvTranspose2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseConvTranspose3d
,
'SparseConvTranspose3d'
,
force
=
True
)
SparseConvTranspose3d
,
'SparseConvTranspose3d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseInverseConv2d
,
'SparseInverseConv2d'
,
force
=
True
)
SparseInverseConv2d
,
'SparseInverseConv2d'
,
force
=
True
)
CONV_LAYER
S
.
_register_module
(
MODEL
S
.
_register_module
(
SparseInverseConv3d
,
'SparseInverseConv3d'
,
force
=
True
)
SparseInverseConv3d
,
'SparseInverseConv3d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv2d
,
'SubMConv2d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv2d
,
'SubMConv2d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv3d
,
'SubMConv3d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv3d
,
'SubMConv3d'
,
force
=
True
)
CONV_LAYERS
.
_register_module
(
SubMConv4d
,
'SubMConv4d'
,
force
=
True
)
MODELS
.
_register_module
(
SubMConv4d
,
'SubMConv4d'
,
force
=
True
)
SparseModule
.
_version
=
2
SparseModule
.
_load_from_state_dict
=
_load_from_state_dict
SparseModule
.
_load_from_state_dict
=
_load_from_state_dict
SparseModule
.
_save_to_state_dict
=
_save_to_state_dict
return
True
return
True
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
):
"""Rewrite this func to compat the convolutional kernel weights between
spconv 1.x in MMCV and 2.x in spconv2.x.
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
param
=
param
if
keep_vars
else
param
.
detach
()
if
name
==
'weight'
:
dims
=
list
(
range
(
1
,
len
(
param
.
shape
)))
+
[
0
]
param
=
param
.
permute
(
*
dims
)
destination
[
prefix
+
name
]
=
param
for
name
,
buf
in
self
.
_buffers
.
items
():
if
buf
is
not
None
and
name
not
in
self
.
_non_persistent_buffers_set
:
destination
[
prefix
+
name
]
=
buf
if
keep_vars
else
buf
.
detach
()
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
missing_keys
,
unexpected_keys
,
error_msgs
):
"""Rewrite this func to compat the convolutional kernel weights between
"""Rewrite this func to compat the convolutional kernel weights between
...
@@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
...
@@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
"""
version
=
local_metadata
.
get
(
'version'
,
None
)
for
hook
in
self
.
_load_state_dict_pre_hooks
.
values
():
for
hook
in
self
.
_load_state_dict_pre_hooks
.
values
():
hook
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
hook
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
)
unexpected_keys
,
error_msgs
)
...
@@ -83,9 +65,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
...
@@ -83,9 +65,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
# 0.3.* to version 0.4+
# 0.3.* to version 0.4+
if
len
(
param
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
if
len
(
param
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
input_param
=
input_param
[
0
]
input_param
=
input_param
[
0
]
dims
=
[
len
(
input_param
.
shape
)
-
1
]
+
list
(
if
version
!=
2
:
range
(
len
(
input_param
.
shape
)
-
1
))
dims
=
[
len
(
input_param
.
shape
)
-
1
]
+
list
(
input_param
=
input_param
.
permute
(
*
dims
)
range
(
len
(
input_param
.
shape
)
-
1
))
input_param
=
input_param
.
permute
(
*
dims
)
if
input_param
.
shape
!=
param
.
shape
:
if
input_param
.
shape
!=
param
.
shape
:
# local shape should match the one in checkpoint
# local shape should match the one in checkpoint
error_msgs
.
append
(
error_msgs
.
append
(
...
...
mmdet3d/models/layers/transformer.py
View file @
0e157c31
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.cnn.bricks.
registry
import
ATTENTION
from
mmcv.cnn.bricks.
transformer
import
MultiheadAttention
from
mm
cv.cnn.bricks.transformer
import
POSITIONAL_ENCODING
,
MultiheadAttention
from
mm
engine.registry
import
MODELS
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
@
ATTENTION
.
register_module
()
@
MODELS
.
register_module
()
class
GroupFree3DMHA
(
MultiheadAttention
):
class
GroupFree3DMHA
(
MultiheadAttention
):
"""A warpper for torch.nn.MultiheadAttention for GroupFree3D.
"""A warpper for torch.nn.MultiheadAttention for GroupFree3D.
...
@@ -108,7 +108,7 @@ class GroupFree3DMHA(MultiheadAttention):
...
@@ -108,7 +108,7 @@ class GroupFree3DMHA(MultiheadAttention):
**
kwargs
)
**
kwargs
)
@
POSITIONAL_ENCODING
.
register_module
()
@
MODELS
.
register_module
()
class
ConvBNPositionalEncoding
(
nn
.
Module
):
class
ConvBNPositionalEncoding
(
nn
.
Module
):
"""Absolute position embedding with Conv learning.
"""Absolute position embedding with Conv learning.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment