test_vote_module.py 670 Bytes
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
import torch


wuyuefeng's avatar
Votenet  
wuyuefeng committed
4
5
def test_vote_module():
    """Check the output shapes produced by a ``VoteModule`` forward pass.

    A ``VoteModule`` is built with ``vote_per_seed=3`` and ``in_channels=8``
    and called on random seed coordinates and seed features for a batch of
    2 samples with 64 seeds each. The test asserts that the returned vote
    coordinates and vote features hold ``64 * 3 = 192`` entries per sample.
    """
    # Imported inside the test, as in the original, so the project module
    # is only loaded when this test actually runs.
    from mmdet3d.models.model_utils import VoteModule

    batch_size = 2
    num_seeds = 64
    in_channels = 8
    vote_per_seed = 3

    vote_loss = dict(
        type='ChamferDistance',
        mode='l1',
        reduction='none',
        loss_dst_weight=10.0)
    vote_module = VoteModule(
        vote_per_seed=vote_per_seed,
        in_channels=in_channels,
        vote_loss=vote_loss)

    # (b, npoints, 3)
    seed_xyz = torch.rand([batch_size, num_seeds, 3], dtype=torch.float32)
    # (b, in_channels, npoints)
    seed_features = torch.rand(
        [batch_size, in_channels, num_seeds], dtype=torch.float32)

    # test forward
    vote_xyz, vote_features = vote_module(seed_xyz, seed_features)

    num_votes = num_seeds * vote_per_seed  # 64 * 3 == 192
    assert vote_xyz.shape == torch.Size([batch_size, num_votes, 3])
    assert vote_features.shape == torch.Size(
        [batch_size, in_channels, num_votes])