test_sparse_unetv2.py 1.3 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch


def test_SparseUnetV2():
    """Smoke-test the SparseUnetV2 forward pass on a tiny two-sample batch.

    Four voxels (two per sample), each with 4 features, are fed through a
    ``SparseUnetV2`` built with ``in_channels=4``,
    ``output_shape=[41, 1600, 1408]`` and ``pre_act=False``. The test then
    checks the shapes of the per-voxel outputs (``u_seg_preds``,
    ``u_reg_preds``, ``seg_features``) and of the dense ``spatial_features``
    entry of the returned dict.
    """
    from mmdet3d.models.middle_encoders.sparse_unetv2 import SparseUnetV2

    # A free function, not a method: use a descriptive name instead of
    # ``self`` so readers do not look for an enclosing class.
    unet = SparseUnetV2(
        in_channels=4, output_shape=[41, 1600, 1408], pre_act=False)

    voxel_features = torch.tensor(
        [[6.56126, 0.9648336, -1.7339306, 0.315],
         [6.8162713, -2.480431, -1.3616394, 0.36],
         [11.643568, -4.744306, -1.3580885, 0.16],
         [23.482342, 6.5036807, 0.5806964, 0.35]],
        dtype=torch.float32)  # (num_voxels, num_point_features)

    # Column 0 is the sample index within the batch; columns 1-3 are voxel
    # grid indices. NOTE(review): the index columns appear to follow the axis
    # order of output_shape=[41, 1600, 1408] (presumably z, y, x). Column 1
    # never reaches 41, while columns 2 and 3 go well past it. Confirm
    # against SparseUnetV2 before relying on this.
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32)

    num_voxels = voxel_features.shape[0]
    batch_size = 2  # sample indices 0 and 1 appear in coordinates[:, 0]

    # Call the module rather than ``.forward`` so nn.Module hooks would run.
    unet_ret_dict = unet(voxel_features, coordinates, batch_size)
    seg_cls_preds = unet_ret_dict['u_seg_preds']
    seg_reg_preds = unet_ret_dict['u_reg_preds']
    seg_features = unet_ret_dict['seg_features']
    spatial_features = unet_ret_dict['spatial_features']

    assert seg_cls_preds.shape == torch.Size([num_voxels, 1])
    assert seg_reg_preds.shape == torch.Size([num_voxels, 3])
    assert seg_features.shape == torch.Size([num_voxels, 16])
    # 200 == 1600 / 8 and 176 == 1408 / 8, consistent with an 8x reduction of
    # the two larger grid axes.
    assert spatial_features.shape == torch.Size([batch_size, 256, 200, 176])


if __name__ == '__main__':
    # Allow running this test file directly as a script, outside pytest.
    test_SparseUnetV2()