"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8183d0f16ea28498ea936c0ed0fb338a8b3e523c"
Unverified Commit 6cd23071 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Support FLOPs and Params calculation in MMDet3D (#736)

* add get_flops.py

* modify SSD to support flops calc

* update docs

* support mono3d models
parent 0ddbd0ee
...@@ -124,7 +124,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task ...@@ -124,7 +124,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task
# Model Complexity # Model Complexity
You can use `tools/analysis_tools/get_flops.py` in MMDetection, a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), to compute the FLOPs and params of a given model. You can use `tools/analysis_tools/get_flops.py` in MMDetection3D, a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), to compute the FLOPs and params of a given model.
```shell ```shell
python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}] python tools/analysis_tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
...@@ -134,9 +134,9 @@ You will get the results like this. ...@@ -134,9 +134,9 @@ You will get the results like this.
```text ```text
============================== ==============================
Input shape: (3, 1280, 800) Input shape: (40000, 4)
Flops: 239.32 GFLOPs Flops: 5.78 GFLOPs
Params: 37.74 M Params: 953.83 k
============================== ==============================
``` ```
...@@ -145,9 +145,9 @@ number is absolutely correct. You may well use the result for simple ...@@ -145,9 +145,9 @@ number is absolutely correct. You may well use the result for simple
comparisons, but double check it before you adopt it in technical reports or papers. comparisons, but double check it before you adopt it in technical reports or papers.
1. FLOPs are related to the input shape while parameters are not. The default 1. FLOPs are related to the input shape while parameters are not. The default
input shape is (1, 3, 1280, 800). input shape is (1, 40000, 4).
2. Some operators are not counted into FLOPs like GN and custom operators. Refer to [`mmcv.cnn.get_model_complexity_info()`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py) for details. 2. Some operators are not counted into FLOPs like GN and custom operators. Refer to [`mmcv.cnn.get_model_complexity_info()`](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py) for details.
3. The FLOPs of two-stage detectors is dependent on the number of proposals. 3. We currently only support FLOPs calculation of single-stage models with single-modality input (point cloud or image). We will support two-stage and multi-modality models in the future.
   
......
...@@ -38,6 +38,19 @@ class SingleStage3DDetector(Base3DDetector): ...@@ -38,6 +38,19 @@ class SingleStage3DDetector(Base3DDetector):
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
def forward_dummy(self, points):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
x = self.extract_feat(points)
try:
sample_mod = self.train_cfg.sample_mod
outs = self.bbox_head(x, sample_mod)
except AttributeError:
outs = self.bbox_head(x)
return outs
def extract_feat(self, points, img_metas=None): def extract_feat(self, points, img_metas=None):
"""Directly extract features from the backbone+neck. """Directly extract features from the backbone+neck.
......
...@@ -36,7 +36,7 @@ class VoxelNet(SingleStage3DDetector): ...@@ -36,7 +36,7 @@ class VoxelNet(SingleStage3DDetector):
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder) self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_metas): def extract_feat(self, points, img_metas=None):
"""Extract features from points.""" """Extract features from points."""
voxels, num_points, coors = self.voxelize(points) voxels, num_points, coors = self.voxelize(points)
voxel_features = self.voxel_encoder(voxels, num_points, coors) voxel_features = self.voxel_encoder(voxels, num_points, coors)
......
import argparse
import torch
from mmcv import Config, DictAction
from mmdet3d.models import build_model
try:
from mmcv.cnn import get_model_complexity_info
except ImportError:
raise ImportError('Please upgrade mmcv to >0.6.2')
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[40000, 4],
help='input point cloud size')
parser.add_argument(
'--modality',
type=str,
default='point',
choices=['point', 'image', 'multi'],
help='input data modality')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.modality == 'point':
assert len(args.shape) == 2, 'invalid input shape'
input_shape = tuple(args.shape)
elif args.modality == 'image':
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
elif args.modality == 'multi':
raise NotImplementedError(
'FLOPs counter is currently not supported for models with '
'multi-modality input')
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
model = build_model(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
if torch.cuda.is_available():
model.cuda()
model.eval()
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not supported for {}'.format(
model.__class__.__name__))
flops, params = get_model_complexity_info(model, input_shape)
split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment