Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
4ff13616
Unverified
Commit
4ff13616
authored
Apr 12, 2023
by
chriscarving
Committed by
GitHub
Apr 12, 2023
Browse files
[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG (#2396)
parent
9f61effd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
29 deletions
+43
-29
.pre-commit-config-zh-cn.yaml
.pre-commit-config-zh-cn.yaml
+4
-10
mmdet3d/models/backbones/pointnet2_sa_msg.py
mmdet3d/models/backbones/pointnet2_sa_msg.py
+36
-18
mmdet3d/models/layers/pointnet_modules/builder.py
mmdet3d/models/layers/pointnet_modules/builder.py
+3
-1
No files found.
.pre-commit-config-zh-cn.yaml
View file @
4ff13616
exclude
:
^tests/data/
repos
:
repos
:
-
repo
:
https://gitee.com/openmmlab/mirrors-flake8
-
repo
:
https://gitee.com/openmmlab/mirrors-flake8
rev
:
5.0.4
rev
:
5.0.4
...
@@ -25,6 +24,10 @@ repos:
...
@@ -25,6 +24,10 @@ repos:
args
:
[
"
--remove"
]
args
:
[
"
--remove"
]
-
id
:
mixed-line-ending
-
id
:
mixed-line-ending
args
:
[
"
--fix=lf"
]
args
:
[
"
--fix=lf"
]
-
repo
:
https://gitee.com/openmmlab/mirrors-codespell
rev
:
v2.2.1
hooks
:
-
id
:
codespell
-
repo
:
https://gitee.com/openmmlab/mirrors-mdformat
-
repo
:
https://gitee.com/openmmlab/mirrors-mdformat
rev
:
0.7.9
rev
:
0.7.9
hooks
:
hooks
:
...
@@ -34,20 +37,11 @@ repos:
...
@@ -34,20 +37,11 @@ repos:
-
mdformat-openmmlab
-
mdformat-openmmlab
-
mdformat_frontmatter
-
mdformat_frontmatter
-
linkify-it-py
-
linkify-it-py
-
repo
:
https://gitee.com/openmmlab/mirrors-codespell
rev
:
v2.2.1
hooks
:
-
id
:
codespell
-
repo
:
https://gitee.com/openmmlab/mirrors-docformatter
-
repo
:
https://gitee.com/openmmlab/mirrors-docformatter
rev
:
v1.3.1
rev
:
v1.3.1
hooks
:
hooks
:
-
id
:
docformatter
-
id
:
docformatter
args
:
[
"
--in-place"
,
"
--wrap-descriptions"
,
"
79"
]
args
:
[
"
--in-place"
,
"
--wrap-descriptions"
,
"
79"
]
-
repo
:
https://gitee.com/openmmlab/mirrors-pyupgrade
rev
:
v3.0.0
hooks
:
-
id
:
pyupgrade
args
:
[
"
--py36-plus"
]
-
repo
:
https://gitee.com/openmmlab/pre-commit-hooks
-
repo
:
https://gitee.com/openmmlab/pre-commit-hooks
rev
:
v0.2.0
rev
:
v0.2.0
hooks
:
hooks
:
...
...
mmdet3d/models/backbones/pointnet2_sa_msg.py
View file @
4ff13616
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Tuple
import
torch
import
torch
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
mmdet3d.models.layers.pointnet_modules
import
build_sa_module
from
mmdet3d.models.layers.pointnet_modules
import
build_sa_module
from
mmdet3d.registry
import
MODELS
from
mmdet3d.registry
import
MODELS
from
mmdet3d.utils
import
OptConfigType
from
.base_pointnet
import
BasePointNet
from
.base_pointnet
import
BasePointNet
ThreeTupleIntType
=
Tuple
[
Tuple
[
Tuple
[
int
,
int
,
int
]]]
TwoTupleIntType
=
Tuple
[
Tuple
[
int
,
int
,
int
]]
TwoTupleStrType
=
Tuple
[
Tuple
[
str
]]
@
MODELS
.
register_module
()
@
MODELS
.
register_module
()
class
PointNet2SAMSG
(
BasePointNet
):
class
PointNet2SAMSG
(
BasePointNet
):
...
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
...
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation
aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features.
multi-scale grouping features.
fps_mods
(tuple[int])
: Mod of FPS for each SA module.
fps_mods
Sequence[Tuple[str]]
: Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples.
points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for
dilated_group (tuple[bool]): Whether to use dilated ball query for
...
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
...
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
:
int
,
num_points
=
(
2048
,
1024
,
512
,
256
),
num_points
:
Tuple
[
int
]
=
(
2048
,
1024
,
512
,
256
),
radii
=
((
0.2
,
0.4
,
0.8
),
(
0.4
,
0.8
,
1.6
),
(
1.6
,
3.2
,
4.8
)),
radii
:
Tuple
[
Tuple
[
float
,
float
,
float
]]
=
(
num_samples
=
((
32
,
32
,
64
),
(
32
,
32
,
64
),
(
32
,
32
,
32
)),
(
0.2
,
0.4
,
0.8
),
sa_channels
=
(((
16
,
16
,
32
),
(
16
,
16
,
32
),
(
32
,
32
,
64
)),
(
0.4
,
0.8
,
1.6
),
((
64
,
64
,
128
),
(
64
,
64
,
128
),
(
64
,
96
,
128
)),
(
1.6
,
3.2
,
4.8
),
((
128
,
128
,
256
),
(
128
,
192
,
256
),
(
128
,
256
,
),
num_samples
:
TwoTupleIntType
=
((
32
,
32
,
64
),
(
32
,
32
,
64
),
(
32
,
32
,
32
)),
sa_channels
:
ThreeTupleIntType
=
(((
16
,
16
,
32
),
(
16
,
16
,
32
),
(
32
,
32
,
64
)),
((
64
,
64
,
128
),
(
64
,
64
,
128
),
(
64
,
96
,
128
)),
((
128
,
128
,
256
),
(
128
,
192
,
256
),
(
128
,
256
,
256
))),
256
))),
aggregation_channels
=
(
64
,
128
,
256
),
aggregation_channels
:
Tuple
[
int
]
=
(
64
,
128
,
256
),
fps_mods
=
((
'D-FPS'
),
(
'FS'
),
(
'F-FPS'
,
'D-FPS'
)),
fps_mods
:
TwoTupleStrType
=
((
'D-FPS'
),
(
'FS'
),
(
'F-FPS'
,
fps_sample_range_lists
=
((
-
1
),
(
-
1
),
(
512
,
-
1
)),
'D-FPS'
)),
dilated_group
=
(
True
,
True
,
True
),
fps_sample_range_lists
:
TwoTupleIntType
=
((
-
1
),
(
-
1
),
(
512
,
out_indices
=
(
2
,
),
-
1
)),
norm_cfg
=
dict
(
type
=
'BN2d'
),
dilated_group
:
Tuple
[
bool
]
=
(
True
,
True
,
True
),
sa_cfg
=
dict
(
out_indices
:
Tuple
[
int
]
=
(
2
,
),
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
sa_cfg
:
dict
=
dict
(
type
=
'PointSAModuleMSG'
,
type
=
'PointSAModuleMSG'
,
pool_mod
=
'max'
,
pool_mod
=
'max'
,
use_xyz
=
True
,
use_xyz
=
True
,
normalize_xyz
=
False
),
normalize_xyz
=
False
),
init_cfg
=
None
):
init_cfg
:
OptConfigType
=
None
):
super
().
__init__
(
init_cfg
=
init_cfg
)
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
num_sa
=
len
(
sa_channels
)
self
.
num_sa
=
len
(
sa_channels
)
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
...
@@ -123,7 +141,7 @@ class PointNet2SAMSG(BasePointNet):
...
@@ -123,7 +141,7 @@ class PointNet2SAMSG(BasePointNet):
bias
=
True
))
bias
=
True
))
sa_in_channel
=
cur_aggregation_channel
sa_in_channel
=
cur_aggregation_channel
def
forward
(
self
,
points
):
def
forward
(
self
,
points
:
torch
.
Tensor
):
"""Forward pass.
"""Forward pass.
Args:
Args:
...
...
mmdet3d/models/layers/pointnet_modules/builder.py
View file @
4ff13616
...
@@ -4,7 +4,9 @@ from typing import Union
...
@@ -4,7 +4,9 @@ from typing import Union
from
mmengine.registry
import
Registry
from
mmengine.registry
import
Registry
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
SA_MODULES
=
Registry
(
'point_sa_module'
)
SA_MODULES
=
Registry
(
name
=
'point_sa_module'
,
locations
=
[
'mmdet3d.models.layers.pointnet_modules'
])
def
build_sa_module
(
cfg
:
Union
[
dict
,
None
],
*
args
,
**
kwargs
)
->
nn
.
Module
:
def
build_sa_module
(
cfg
:
Union
[
dict
,
None
],
*
args
,
**
kwargs
)
->
nn
.
Module
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment