# test_vote_module.py
import torch


# Votenet: unit test for VoteModule.
def test_vote_module():
    """Check the output shapes and offset clipping of ``VoteModule``.

    Two configurations are exercised on the same random seed points:

    1. ``vote_per_seed=3`` with a Chamfer-distance vote loss: the 64 seeds
       are expected to yield 192 votes.
    2. ``vote_per_seed=1``, ``num_points=32``, ``with_res_feat=False`` and
       ``vote_xyz_range=(2.0, 2.0, 2.0)``: the outputs are expected to hold
       32 points, the features must equal the first 32 seed features, and
       every offset must lie within ``[-2.0, 2.0]``.
    """
    # Imported lazily so that merely collecting this module does not require
    # mmdet3d to be importable.
    from mmdet3d.models.model_utils import VoteModule

    vote_loss = dict(
        type='ChamferDistance',
        mode='l1',
        reduction='none',
        loss_dst_weight=10.0)
    vote_module = VoteModule(
        vote_per_seed=3, in_channels=8, vote_loss=vote_loss)

    seed_xyz = torch.rand([2, 64, 3], dtype=torch.float32)  # (b, npoints, 3)
    seed_features = torch.rand(
        [2, 8, 64], dtype=torch.float32)  # (b, in_channels, npoints)

    # test forward: 64 seeds * 3 votes per seed -> 192 votes
    vote_xyz, vote_features, vote_offset = vote_module(
        seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 192, 3])
    assert vote_features.shape == torch.Size([2, 8, 192])
    assert vote_offset.shape == torch.Size([2, 3, 192])

    # test clip offset and without feature residual
    vote_module = VoteModule(
        vote_per_seed=1,
        in_channels=8,
        num_points=32,
        with_res_feat=False,
        vote_xyz_range=(2.0, 2.0, 2.0))

    vote_xyz, vote_features, vote_offset = vote_module(
        seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 32, 3])
    assert vote_features.shape == torch.Size([2, 8, 32])
    assert vote_offset.shape == torch.Size([2, 3, 32])
    # Without the residual branch the vote features must be the (truncated)
    # seed features, unchanged.
    assert torch.allclose(seed_features[..., :32], vote_features)
    # Offsets must respect the configured per-axis range of 2.0.
    assert vote_offset.max() <= 2.0
    assert vote_offset.min() >= -2.0