Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
4ff13616
Unverified
Commit
4ff13616
authored
Apr 12, 2023
by
chriscarving
Committed by
GitHub
Apr 12, 2023
Browse files
[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG (#2396)
parent
9f61effd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
29 deletions
+43
-29
.pre-commit-config-zh-cn.yaml
.pre-commit-config-zh-cn.yaml
+4
-10
mmdet3d/models/backbones/pointnet2_sa_msg.py
mmdet3d/models/backbones/pointnet2_sa_msg.py
+36
-18
mmdet3d/models/layers/pointnet_modules/builder.py
mmdet3d/models/layers/pointnet_modules/builder.py
+3
-1
No files found.
.pre-commit-config-zh-cn.yaml
View file @
4ff13616
exclude
:
^tests/data/
repos
:
-
repo
:
https://gitee.com/openmmlab/mirrors-flake8
rev
:
5.0.4
...
...
@@ -25,6 +24,10 @@ repos:
args
:
[
"
--remove"
]
-
id
:
mixed-line-ending
args
:
[
"
--fix=lf"
]
-
repo
:
https://gitee.com/openmmlab/mirrors-codespell
rev
:
v2.2.1
hooks
:
-
id
:
codespell
-
repo
:
https://gitee.com/openmmlab/mirrors-mdformat
rev
:
0.7.9
hooks
:
...
...
@@ -34,20 +37,11 @@ repos:
-
mdformat-openmmlab
-
mdformat_frontmatter
-
linkify-it-py
-
repo
:
https://gitee.com/openmmlab/mirrors-codespell
rev
:
v2.2.1
hooks
:
-
id
:
codespell
-
repo
:
https://gitee.com/openmmlab/mirrors-docformatter
rev
:
v1.3.1
hooks
:
-
id
:
docformatter
args
:
[
"
--in-place"
,
"
--wrap-descriptions"
,
"
79"
]
-
repo
:
https://gitee.com/openmmlab/mirrors-pyupgrade
rev
:
v3.0.0
hooks
:
-
id
:
pyupgrade
args
:
[
"
--py36-plus"
]
-
repo
:
https://gitee.com/openmmlab/pre-commit-hooks
rev
:
v0.2.0
hooks
:
...
...
mmdet3d/models/backbones/pointnet2_sa_msg.py
View file @
4ff13616
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Tuple
import
torch
from
mmcv.cnn
import
ConvModule
from
torch
import
nn
as
nn
from
mmdet3d.models.layers.pointnet_modules
import
build_sa_module
from
mmdet3d.registry
import
MODELS
from
mmdet3d.utils
import
OptConfigType
from
.base_pointnet
import
BasePointNet
ThreeTupleIntType
=
Tuple
[
Tuple
[
Tuple
[
int
,
int
,
int
]]]
TwoTupleIntType
=
Tuple
[
Tuple
[
int
,
int
,
int
]]
TwoTupleStrType
=
Tuple
[
Tuple
[
str
]]
@
MODELS
.
register_module
()
class
PointNet2SAMSG
(
BasePointNet
):
...
...
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features.
fps_mods
(tuple[int])
: Mod of FPS for each SA module.
fps_mods
Sequence[Tuple[str]]
: Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for
...
...
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
"""
def
__init__
(
self
,
in_channels
,
num_points
=
(
2048
,
1024
,
512
,
256
),
radii
=
((
0.2
,
0.4
,
0.8
),
(
0.4
,
0.8
,
1.6
),
(
1.6
,
3.2
,
4.8
)),
num_samples
=
((
32
,
32
,
64
),
(
32
,
32
,
64
),
(
32
,
32
,
32
)),
sa_channels
=
(((
16
,
16
,
32
),
(
16
,
16
,
32
),
(
32
,
32
,
64
)),
((
64
,
64
,
128
),
(
64
,
64
,
128
),
(
64
,
96
,
128
)),
((
128
,
128
,
256
),
(
128
,
192
,
256
),
(
128
,
256
,
256
))),
aggregation_channels
=
(
64
,
128
,
256
),
fps_mods
=
((
'D-FPS'
),
(
'FS'
),
(
'F-FPS'
,
'D-FPS'
)),
fps_sample_range_lists
=
((
-
1
),
(
-
1
),
(
512
,
-
1
)),
dilated_group
=
(
True
,
True
,
True
),
out_indices
=
(
2
,
),
norm_cfg
=
dict
(
type
=
'BN2d'
),
sa_cfg
=
dict
(
in_channels
:
int
,
num_points
:
Tuple
[
int
]
=
(
2048
,
1024
,
512
,
256
),
radii
:
Tuple
[
Tuple
[
float
,
float
,
float
]]
=
(
(
0.2
,
0.4
,
0.8
),
(
0.4
,
0.8
,
1.6
),
(
1.6
,
3.2
,
4.8
),
),
num_samples
:
TwoTupleIntType
=
((
32
,
32
,
64
),
(
32
,
32
,
64
),
(
32
,
32
,
32
)),
sa_channels
:
ThreeTupleIntType
=
(((
16
,
16
,
32
),
(
16
,
16
,
32
),
(
32
,
32
,
64
)),
((
64
,
64
,
128
),
(
64
,
64
,
128
),
(
64
,
96
,
128
)),
((
128
,
128
,
256
),
(
128
,
192
,
256
),
(
128
,
256
,
256
))),
aggregation_channels
:
Tuple
[
int
]
=
(
64
,
128
,
256
),
fps_mods
:
TwoTupleStrType
=
((
'D-FPS'
),
(
'FS'
),
(
'F-FPS'
,
'D-FPS'
)),
fps_sample_range_lists
:
TwoTupleIntType
=
((
-
1
),
(
-
1
),
(
512
,
-
1
)),
dilated_group
:
Tuple
[
bool
]
=
(
True
,
True
,
True
),
out_indices
:
Tuple
[
int
]
=
(
2
,
),
norm_cfg
:
dict
=
dict
(
type
=
'BN2d'
),
sa_cfg
:
dict
=
dict
(
type
=
'PointSAModuleMSG'
,
pool_mod
=
'max'
,
use_xyz
=
True
,
normalize_xyz
=
False
),
init_cfg
=
None
):
init_cfg
:
OptConfigType
=
None
):
super
().
__init__
(
init_cfg
=
init_cfg
)
self
.
num_sa
=
len
(
sa_channels
)
self
.
out_indices
=
out_indices
...
...
@@ -123,7 +141,7 @@ class PointNet2SAMSG(BasePointNet):
bias
=
True
))
sa_in_channel
=
cur_aggregation_channel
def
forward
(
self
,
points
):
def
forward
(
self
,
points
:
torch
.
Tensor
):
"""Forward pass.
Args:
...
...
mmdet3d/models/layers/pointnet_modules/builder.py
View file @
4ff13616
...
...
@@ -4,7 +4,9 @@ from typing import Union
from
mmengine.registry
import
Registry
from
torch
import
nn
as
nn
SA_MODULES
=
Registry
(
'point_sa_module'
)
SA_MODULES
=
Registry
(
name
=
'point_sa_module'
,
locations
=
[
'mmdet3d.models.layers.pointnet_modules'
])
def
build_sa_module
(
cfg
:
Union
[
dict
,
None
],
*
args
,
**
kwargs
)
->
nn
.
Module
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment