Unverified Commit 7b1f6a85 authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

add primitive head (#53)

* add primitive head

* register of primitive head

* modify primitive head

* modify primitive head

* modify primitive head

* modify primitive head

* update primitive head unittest

* modify primitive had

* fix bugs for primitive head

* update primitive head
parent ba5fa548
......@@ -126,7 +126,7 @@ class VoteModule(nn.Module):
seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
1, 1, 3 * self.gt_per_seed)
seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
seed_gt_votes += seed_points.repeat(1, 1, 3)
seed_gt_votes += seed_points.repeat(1, 1, self.gt_per_seed)
weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6)
distance = self.vote_loss(
......
from .pointwise_semantic_head import PointwiseSemanticHead
from .primitive_head import PrimitiveHead
__all__ = ['PointwiseSemanticHead']
__all__ = ['PointwiseSemanticHead', 'PrimitiveHead']
This diff is collapsed.
......@@ -457,3 +457,115 @@ def test_free_anchor_3D_head():
gt_labels, input_metas, None)
assert losses['positive_bag_loss'] >= 0
assert losses['negative_bag_loss'] >= 0
def test_primitive_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
primitive_head_cfg = dict(
type='PrimitiveHead',
num_dims=2,
num_classes=18,
primitive_mode='z',
vote_moudule_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
vote_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
num_point=64,
radius=0.3,
num_sample=16,
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=1.0),
center_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_reg_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
self = build_head(primitive_head_cfg).cuda()
fp_xyz = [torch.rand([2, 64, 3], dtype=torch.float32).cuda()]
hd_features = torch.rand([2, 256, 64], dtype=torch.float32).cuda()
fp_indices = [torch.randint(0, 64, [2, 64]).cuda()]
input_dict = dict(
fp_xyz_net0=fp_xyz, hd_feature=hd_features, fp_indices_net0=fp_indices)
# test forward
ret_dict = self(input_dict, 'vote')
assert ret_dict['center_z'].shape == torch.Size([2, 64, 3])
assert ret_dict['size_residuals_z'].shape == torch.Size([2, 64, 2])
assert ret_dict['sem_cls_scores_z'].shape == torch.Size([2, 64, 18])
assert ret_dict['aggregated_points_z'].shape == torch.Size([2, 64, 3])
# test loss
points = torch.rand([2, 1024, 3], dtype=torch.float32).cuda()
ret_dict['seed_points'] = fp_xyz[0]
ret_dict['seed_indices'] = fp_indices[0]
from mmdet3d.core.bbox import DepthInstance3DBoxes
gt_bboxes_3d = [
DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda()),
DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda())
]
gt_labels_3d = torch.randint(0, 18, [2, 4]).cuda()
gt_labels_3d = [gt_labels_3d[0], gt_labels_3d[1]]
pts_semantic_mask = torch.randint(0, 19, [2, 1024]).cuda()
pts_semantic_mask = [pts_semantic_mask[0], pts_semantic_mask[1]]
pts_instance_mask = torch.randint(0, 4, [2, 1024]).cuda()
pts_instance_mask = [pts_instance_mask[0], pts_instance_mask[1]]
loss_input_dict = dict(
bbox_preds=ret_dict,
points=points,
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
pts_semantic_mask=pts_semantic_mask,
pts_instance_mask=pts_instance_mask)
losses_dict = self.loss(**loss_input_dict)
assert losses_dict['flag_loss_z'] >= 0
assert losses_dict['vote_loss_z'] >= 0
assert losses_dict['center_loss_z'] >= 0
assert losses_dict['size_loss_z'] >= 0
assert losses_dict['sem_loss_z'] >= 0
# 'Primitive_mode' should be one of ['z', 'xy', 'line']
with pytest.raises(AssertionError):
primitive_head_cfg['vote_moudule_cfg']['in_channels'] = 'xyz'
build_head(primitive_head_cfg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment