test_voting_module.py 507 Bytes
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch


def test_voting_module():
    """Check the output shapes of ``VoteModule``'s forward pass.

    Build a module with ``vote_per_seed=3`` and ``in_channels=8``. Feed it
    random seed coordinates and features. Verify that it yields
    ``num_seeds * vote_per_seed`` votes per batch element, for both the
    vote coordinates and the vote features.
    """
    from mmdet3d.ops import VoteModule

    batch_size = 2
    num_seeds = 64
    in_channels = 8
    vote_per_seed = 3

    # A plain function, not a method: avoid naming the instance ``self``.
    vote_module = VoteModule(
        vote_per_seed=vote_per_seed, in_channels=in_channels)

    # (b, npoints, 3)
    seed_xyz = torch.rand([batch_size, num_seeds, 3], dtype=torch.float32)
    # (b, in_channels, npoints)
    seed_features = torch.rand(
        [batch_size, in_channels, num_seeds], dtype=torch.float32)

    # test forward
    vote_xyz, vote_features = vote_module(seed_xyz, seed_features)

    # The expected output has vote_per_seed votes per seed (64 * 3 = 192).
    num_votes = num_seeds * vote_per_seed
    assert vote_xyz.shape == torch.Size([batch_size, num_votes, 3])
    assert vote_features.shape == torch.Size(
        [batch_size, in_channels, num_votes])