Commit 3c5ff9fa authored by zhangwenwei
Browse files

Support test time augmentation

parent f6e95edd
......@@ -100,7 +100,7 @@ class SECONDFusionFPN(SECONDFPN):
coors=None,
points=None,
img_feats=None,
img_meta=None):
img_metas=None):
assert len(x) == len(self.in_channels)
ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]
......@@ -119,5 +119,5 @@ class SECONDFusionFPN(SECONDFPN):
coors[:, 3] / self.downsample_rates[2])
# fusion for each point
out = self.fusion_layer(img_feats, points, out,
downsample_pts_coors, img_meta)
downsample_pts_coors, img_metas)
return [out]
......@@ -51,7 +51,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
@abstractmethod
def forward_train(self,
x,
img_meta,
img_metas,
proposal_list,
gt_bboxes,
gt_labels,
......@@ -64,7 +64,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
def simple_test(self,
x,
proposal_list,
img_meta,
img_metas,
proposals=None,
rescale=False,
**kwargs):
......
......@@ -441,7 +441,7 @@ class PartA2BboxHead(nn.Module):
bbox_pred,
class_labels,
class_pred,
img_meta,
img_metas,
cfg=None):
roi_batch_id = rois[..., 0]
roi_boxes = rois[..., 1:] # boxes without batch id
......@@ -474,8 +474,8 @@ class PartA2BboxHead(nn.Module):
selected_scores = cur_cls_score[selected]
result_list.append(
(img_meta[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size),
(img_metas[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size),
selected_scores, selected_label_preds))
return result_list
......
......@@ -59,7 +59,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return hasattr(self,
'semantic_head') and self.semantic_head is not None
def forward_train(self, feats_dict, voxels_dict, img_meta, proposal_list,
def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
gt_bboxes_3d, gt_labels_3d):
"""Training forward function of PartAggregationROIHead
......@@ -97,7 +97,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return losses
def simple_test(self, feats_dict, voxels_dict, img_meta, proposal_list,
def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
**kwargs):
"""Simple testing forward function of PartAggregationROIHead
......@@ -131,7 +131,7 @@ class PartAggregationROIHead(Base3DRoIHead):
bbox_results['bbox_pred'],
labels_3d,
cls_preds,
img_meta,
img_metas,
cfg=self.test_cfg)
bbox_results = [
......
......@@ -188,7 +188,7 @@ class DynamicVFE(nn.Module):
coors,
points=None,
img_feats=None,
img_meta=None):
img_metas=None):
"""Forward functions
Args:
......@@ -198,7 +198,7 @@ class DynamicVFE(nn.Module):
multi-modality fusion. Defaults to None.
img_feats (list[torch.Tensor], optional): Image features used for
multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None.
img_metas (dict, optional): Meta information of image and points,
passed to the fusion layer. Defaults to None.
Returns:
tuple: If `return_point_feats` is False, returns voxel features and
......@@ -237,7 +237,7 @@ class DynamicVFE(nn.Module):
if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None
and img_feats is not None):
point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta)
img_metas)
voxel_feats, voxel_coors = self.vfe_scatter(point_feats, coors)
if i != len(self.vfe_layers) - 1:
# need to concat voxel feats if it is not the last vfe
......@@ -351,7 +351,7 @@ class HardVFE(nn.Module):
num_points,
coors,
img_feats=None,
img_meta=None):
img_metas=None):
"""Forward functions
Args:
......@@ -360,7 +360,7 @@ class HardVFE(nn.Module):
coors (torch.Tensor): Coordinates of voxels, shape is Mx(1+NDim).
img_feats (list[torch.Tensor], optional): Image features used for
multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None.
img_metas (list[dict], optional): Meta information of image and
points. Defaults to None.
Returns:
tuple: If `return_point_feats` is False, returns voxel features and
......@@ -410,12 +410,12 @@ class HardVFE(nn.Module):
if (self.fusion_layer is not None and img_feats is not None):
voxel_feats = self.fusion_with_mask(features, mask, voxel_feats,
coors, img_feats, img_meta)
coors, img_feats, img_metas)
return voxel_feats
def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats,
img_meta):
img_metas):
"""Fuse image and point features with mask.
Args:
......@@ -425,7 +425,7 @@ class HardVFE(nn.Module):
voxel_feats (torch.Tensor): Features of voxels.
coors (torch.Tensor): Coordinates of each single voxel.
img_feats (list[torch.Tensor]): Multi-scale feature maps of image.
img_meta (list(dict)): Meta information of image and points.
img_metas (list(dict)): Meta information of image and points.
Returns:
torch.Tensor: Fused features of each voxel.
......@@ -439,7 +439,7 @@ class HardVFE(nn.Module):
point_feats = voxel_feats[mask]
point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta)
img_metas)
voxel_canvas = voxel_feats.new_zeros(
size=(voxel_feats.size(0), voxel_feats.size(1),
......
......@@ -27,7 +27,7 @@ def test_getitem():
dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 36, 1 / 36],
scale_range=None),
......@@ -50,11 +50,11 @@ def test_getitem():
gt_labels = data['gt_labels_3d']._data
pts_semantic_mask = data['pts_semantic_mask']._data
pts_instance_mask = data['pts_instance_mask']._data
file_name = data['img_meta']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz']
rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx']
file_name = data['img_metas']._data['file_name']
flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_metas']._data['flip_yz']
rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/scannet/' \
'points/scene0000_00.bin'
assert flip_xz is True
......
......@@ -19,7 +19,7 @@ def test_getitem():
dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]),
......@@ -39,12 +39,12 @@ def test_getitem():
points = data['points']._data
gt_bboxes_3d = data['gt_bboxes_3d']._data
gt_labels_3d = data['gt_labels_3d']._data
file_name = data['img_meta']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz']
scale_ratio = data['img_meta']._data['scale_ratio']
rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx']
file_name = data['img_metas']._data['file_name']
flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_metas']._data['flip_yz']
scale_ratio = data['img_metas']._data['scale_ratio']
rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/sunrgbd' \
'/points/000001.bin'
assert flip_xz is False
......
......@@ -71,7 +71,7 @@ def test_anchor3d_head_loss():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py')
'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
......@@ -123,7 +123,7 @@ def test_anchor3d_head_getboxes():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py')
'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
......
This diff is collapsed.
......@@ -30,7 +30,7 @@ def test_scannet_pipeline():
dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 36, 1 / 36],
scale_range=None),
......@@ -113,7 +113,7 @@ def test_sunrgbd_pipeline():
dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict(
type='IndoorGlobalRotScale',
type='IndoorGlobalRotScaleTrans',
shift_height=True,
rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]),
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment