Commit ce79da2e authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'add-tta' into 'master'

Support test time augmentation

See merge request open-mmlab/mmdet.3d!70
parents f6e95edd 3c5ff9fa
......@@ -100,7 +100,7 @@ class SECONDFusionFPN(SECONDFPN):
coors=None,
points=None,
img_feats=None,
img_meta=None):
img_metas=None):
assert len(x) == len(self.in_channels)
ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]
......@@ -119,5 +119,5 @@ class SECONDFusionFPN(SECONDFPN):
coors[:, 3] / self.downsample_rates[2])
# fusion for each point
out = self.fusion_layer(img_feats, points, out,
downsample_pts_coors, img_meta)
downsample_pts_coors, img_metas)
return [out]
......@@ -51,7 +51,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
@abstractmethod
def forward_train(self,
x,
img_meta,
img_metas,
proposal_list,
gt_bboxes,
gt_labels,
......@@ -64,7 +64,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
def simple_test(self,
x,
proposal_list,
img_meta,
img_metas,
proposals=None,
rescale=False,
**kwargs):
......
......@@ -441,7 +441,7 @@ class PartA2BboxHead(nn.Module):
bbox_pred,
class_labels,
class_pred,
img_meta,
img_metas,
cfg=None):
roi_batch_id = rois[..., 0]
roi_boxes = rois[..., 1:] # boxes without batch id
......@@ -474,8 +474,8 @@ class PartA2BboxHead(nn.Module):
selected_scores = cur_cls_score[selected]
result_list.append(
(img_meta[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size),
(img_metas[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size),
selected_scores, selected_label_preds))
return result_list
......
......@@ -59,7 +59,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return hasattr(self,
'semantic_head') and self.semantic_head is not None
def forward_train(self, feats_dict, voxels_dict, img_meta, proposal_list,
def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
gt_bboxes_3d, gt_labels_3d):
"""Training forward function of PartAggregationROIHead
......@@ -97,7 +97,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return losses
def simple_test(self, feats_dict, voxels_dict, img_meta, proposal_list,
def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
**kwargs):
"""Simple testing forward function of PartAggregationROIHead
......@@ -131,7 +131,7 @@ class PartAggregationROIHead(Base3DRoIHead):
bbox_results['bbox_pred'],
labels_3d,
cls_preds,
img_meta,
img_metas,
cfg=self.test_cfg)
bbox_results = [
......
......@@ -188,7 +188,7 @@ class DynamicVFE(nn.Module):
coors,
points=None,
img_feats=None,
img_meta=None):
img_metas=None):
"""Forward functions
Args:
......@@ -198,7 +198,7 @@ class DynamicVFE(nn.Module):
multi-modality fusion. Defaults to None.
img_feats (list[torch.Tensor], optional): Image fetures used for
multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None.
img_metas (dict, optional): [description]. Defaults to None.
Returns:
tuple: If `return_point_feats` is False, returns voxel features and
......@@ -237,7 +237,7 @@ class DynamicVFE(nn.Module):
if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None
and img_feats is not None):
point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta)
img_metas)
voxel_feats, voxel_coors = self.vfe_scatter(point_feats, coors)
if i != len(self.vfe_layers) - 1:
# need to concat voxel feats if it is not the last vfe
......@@ -351,7 +351,7 @@ class HardVFE(nn.Module):
num_points,
coors,
img_feats=None,
img_meta=None):
img_metas=None):
"""Forward functions
Args:
......@@ -360,7 +360,7 @@ class HardVFE(nn.Module):
coors (torch.Tensor): Coordinates of voxels, shape is Mx(1+NDim).
img_feats (list[torch.Tensor], optional): Image fetures used for
multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None.
img_metas (dict, optional): [description]. Defaults to None.
Returns:
tuple: If `return_point_feats` is False, returns voxel features and
......@@ -410,12 +410,12 @@ class HardVFE(nn.Module):
if (self.fusion_layer is not None and img_feats is not None):
voxel_feats = self.fusion_with_mask(features, mask, voxel_feats,
coors, img_feats, img_meta)
coors, img_feats, img_metas)
return voxel_feats
def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats,
img_meta):
img_metas):
"""Fuse image and point features with mask.
Args:
......@@ -425,7 +425,7 @@ class HardVFE(nn.Module):
voxel_feats (torch.Tensor): Features of voxels.
coors (torch.Tensor): Coordinates of each single voxel.
img_feats (list[torch.Tensor]): Multi-scale feature maps of image.
img_meta (list(dict)): Meta information of image and points.
img_metas (list(dict)): Meta information of image and points.
Returns:
torch.Tensor: Fused features of each voxel.
......@@ -439,7 +439,7 @@ class HardVFE(nn.Module):
point_feats = voxel_feats[mask]
point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta)
img_metas)
voxel_canvas = voxel_feats.new_zeros(
size=(voxel_feats.size(0), voxel_feats.size(1),
......
......@@ -27,7 +27,7 @@ def test_getitem():
dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 36, 1 / 36],
scale_range=None),
......@@ -50,11 +50,11 @@ def test_getitem():
gt_labels = data['gt_labels_3d']._data
pts_semantic_mask = data['pts_semantic_mask']._data
pts_instance_mask = data['pts_instance_mask']._data
file_name = data['img_meta']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz']
rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx']
file_name = data['img_metas']._data['file_name']
flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_metas']._data['flip_yz']
rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/scannet/' \
'points/scene0000_00.bin'
assert flip_xz is True
......
......@@ -19,7 +19,7 @@ def test_getitem():
dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]),
......@@ -39,12 +39,12 @@ def test_getitem():
points = data['points']._data
gt_bboxes_3d = data['gt_bboxes_3d']._data
gt_labels_3d = data['gt_labels_3d']._data
file_name = data['img_meta']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz']
scale_ratio = data['img_meta']._data['scale_ratio']
rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx']
file_name = data['img_metas']._data['file_name']
flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_metas']._data['flip_yz']
scale_ratio = data['img_metas']._data['scale_ratio']
rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/sunrgbd' \
'/points/000001.bin'
assert flip_xz is False
......
......@@ -71,7 +71,7 @@ def test_anchor3d_head_loss():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py')
'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
......@@ -123,7 +123,7 @@ def test_anchor3d_head_getboxes():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py')
'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
......
......@@ -2,7 +2,8 @@ import numpy as np
import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.datasets.pipelines import IndoorFlipData, IndoorGlobalRotScale
from mmdet3d.datasets.pipelines import (IndoorFlipData,
IndoorGlobalRotScaleTrans)
def test_indoor_flip_data():
......@@ -69,7 +70,7 @@ def test_indoor_flip_data():
def test_global_rot_scale():
np.random.seed(0)
sunrgbd_augment = IndoorGlobalRotScale(
sunrgbd_augment = IndoorGlobalRotScaleTrans(
True, rot_range=[-1 / 6, 1 / 6], scale_range=[0.85, 1.15])
sunrgbd_results = dict()
sunrgbd_results['points'] = np.array(
......@@ -100,7 +101,7 @@ def test_global_rot_scale():
expected_sunrgbd_gt_bboxes_3d, 1e-3)
np.random.seed(0)
scannet_augment = IndoorGlobalRotScale(
scannet_augment = IndoorGlobalRotScaleTrans(
True, rot_range=[-1 * 1 / 36, 1 / 36], scale_range=None)
scannet_results = dict()
scannet_results['points'] = np.array(
......
......@@ -30,7 +30,7 @@ def test_scannet_pipeline():
dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 36, 1 / 36],
scale_range=None),
......@@ -113,7 +113,7 @@ def test_sunrgbd_pipeline():
dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]),
......
......@@ -16,14 +16,14 @@ def test_outdoor_aug_pipeline():
dict(
type='ObjectNoise',
num_try=100,
loc_noise_std=[1.0, 1.0, 0.5],
translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0],
rot_uniform_noise=[-0.78539816, 0.78539816]),
rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5),
dict(
type='GlobalRotScale',
rot_uniform_noise=[-0.78539816, 0.78539816],
scaling_uniform_noise=[0.95, 1.05]),
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
......@@ -98,7 +98,7 @@ def test_outdoor_aug_pipeline():
pts_filename='tests/data/kitti/a.bin',
ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d),
bbox3d_fields=[],
)
img_fields=[])
output = pipeline(results)
......@@ -133,10 +133,10 @@ def test_outdoor_velocity_aug_pipeline():
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(
type='GlobalRotScale',
rot_uniform_noise=[-0.3925, 0.3925],
scaling_uniform_noise=[0.95, 1.05],
trans_normal_noise=[0, 0, 0]),
type='GlobalRotScaleTrans',
rot_range=[-0.3925, 0.3925],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D', flip_ratio=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
......@@ -197,7 +197,7 @@ def test_outdoor_velocity_aug_pipeline():
pts_filename='tests/data/kitti/a.bin',
ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d),
bbox3d_fields=[],
)
img_fields=[])
output = pipeline(results)
......
......@@ -3,6 +3,7 @@ import os
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from tools.fuse_conv_bn import fuse_module
......@@ -13,37 +14,6 @@ from mmdet.apis import multi_gpu_test, set_random_seed, single_gpu_test
from mmdet.core import wrap_fp16_model
class MultipleKVAction(argparse.Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
def _parse_int_float_bool(self, val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(',')]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)
def parse_args():
parser = argparse.ArgumentParser(
description='MMDet test (and eval) a model')
......@@ -82,7 +52,7 @@ def parse_args():
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options', nargs='+', action=MultipleKVAction, help='custom options')
'--options', nargs='+', action=DictAction, help='custom options')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
......@@ -109,7 +79,7 @@ def main():
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
cfg = mmcv.Config.fromfile(args.config)
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
......@@ -167,7 +137,7 @@ def main():
rank, _ = get_dist_info()
if rank == 0:
if args.out:
print('\nwriting results to {}'.format(args.out))
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.options is None else args.options
if args.format_only:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment