Unverified Commit 4ff13616 authored by chriscarving's avatar chriscarving Committed by GitHub
Browse files

[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG (#2396)

parent 9f61effd
exclude: ^tests/data/
repos: repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8 - repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4 rev: 5.0.4
...@@ -25,6 +24,10 @@ repos: ...@@ -25,6 +24,10 @@ repos:
args: ["--remove"] args: ["--remove"]
- id: mixed-line-ending - id: mixed-line-ending
args: ["--fix=lf"] args: ["--fix=lf"]
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-mdformat - repo: https://gitee.com/openmmlab/mirrors-mdformat
rev: 0.7.9 rev: 0.7.9
hooks: hooks:
...@@ -34,20 +37,11 @@ repos: ...@@ -34,20 +37,11 @@ repos:
- mdformat-openmmlab - mdformat-openmmlab
- mdformat_frontmatter - mdformat_frontmatter
- linkify-it-py - linkify-it-py
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-docformatter - repo: https://gitee.com/openmmlab/mirrors-docformatter
rev: v1.3.1 rev: v1.3.1
hooks: hooks:
- id: docformatter - id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"] args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://gitee.com/openmmlab/mirrors-pyupgrade
rev: v3.0.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://gitee.com/openmmlab/pre-commit-hooks - repo: https://gitee.com/openmmlab/pre-commit-hooks
rev: v0.2.0 rev: v0.2.0
hooks: hooks:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.layers.pointnet_modules import build_sa_module from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]]
TwoTupleIntType = Tuple[Tuple[int, int, int]]
TwoTupleStrType = Tuple[Tuple[str]]
@MODELS.register_module() @MODELS.register_module()
class PointNet2SAMSG(BasePointNet): class PointNet2SAMSG(BasePointNet):
...@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet): ...@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module. sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features. multi-scale grouping features.
fps_mods (tuple[int]): Mod of FPS for each SA module. fps_mods Sequence[Tuple[str]]: Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples. points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for dilated_group (tuple[bool]): Whether to use dilated ball query for
...@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet): ...@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
""" """
def __init__(self, def __init__(self,
in_channels, in_channels: int,
num_points=(2048, 1024, 512, 256), num_points: Tuple[int] = (2048, 1024, 512, 256),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)), radii: Tuple[Tuple[float, float, float]] = (
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)), (0.2, 0.4, 0.8),
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)), (0.4, 0.8, 1.6),
((64, 64, 128), (64, 64, 128), (64, 96, 128)), (1.6, 3.2, 4.8),
((128, 128, 256), (128, 192, 256), (128, 256, ),
num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64),
(32, 32, 32)),
sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32),
(32, 32, 64)),
((64, 64, 128),
(64, 64, 128), (64, 96,
128)),
((128, 128, 256),
(128, 192, 256), (128, 256,
256))), 256))),
aggregation_channels=(64, 128, 256), aggregation_channels: Tuple[int] = (64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')), fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS',
fps_sample_range_lists=((-1), (-1), (512, -1)), 'D-FPS')),
dilated_group=(True, True, True), fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512,
out_indices=(2, ), -1)),
norm_cfg=dict(type='BN2d'), dilated_group: Tuple[bool] = (True, True, True),
sa_cfg=dict( out_indices: Tuple[int] = (2, ),
norm_cfg: dict = dict(type='BN2d'),
sa_cfg: dict = dict(
type='PointSAModuleMSG', type='PointSAModuleMSG',
pool_mod='max', pool_mod='max',
use_xyz=True, use_xyz=True,
normalize_xyz=False), normalize_xyz=False),
init_cfg=None): init_cfg: OptConfigType = None):
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels) self.num_sa = len(sa_channels)
self.out_indices = out_indices self.out_indices = out_indices
...@@ -123,7 +141,7 @@ class PointNet2SAMSG(BasePointNet): ...@@ -123,7 +141,7 @@ class PointNet2SAMSG(BasePointNet):
bias=True)) bias=True))
sa_in_channel = cur_aggregation_channel sa_in_channel = cur_aggregation_channel
def forward(self, points): def forward(self, points: torch.Tensor):
"""Forward pass. """Forward pass.
Args: Args:
......
...@@ -4,7 +4,9 @@ from typing import Union ...@@ -4,7 +4,9 @@ from typing import Union
from mmengine.registry import Registry from mmengine.registry import Registry
from torch import nn as nn from torch import nn as nn
SA_MODULES = Registry('point_sa_module') SA_MODULES = Registry(
name='point_sa_module',
locations=['mmdet3d.models.layers.pointnet_modules'])
def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module: def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment