"""Tests the core function of vote fusion. CommandLine: pytest tests/test_models/test_fusion/test_vote_fusion.py """ import torch from mmdet3d.models.fusion_layers import VoteFusion def test_vote_fusion(): img_meta = { 'ori_shape': (530, 730, 3), 'img_shape': (600, 826, 3), 'pad_shape': (608, 832, 3), 'scale_factor': torch.tensor([1.1315, 1.1321, 1.1315, 1.1321]), 'flip': False, 'pcd_horizontal_flip': False, 'pcd_vertical_flip': False, 'pcd_trans': torch.tensor([0., 0., 0.]), 'pcd_scale_factor': 1.0308290128214932, 'pcd_rotation': torch.tensor([[0.9747, 0.2234, 0.0000], [-0.2234, 0.9747, 0.0000], [0.0000, 0.0000, 1.0000]]), 'transformation_3d_flow': ['HF', 'R', 'S', 'T'] } calibs = { 'Rt': torch.tensor([[[0.979570, 0.047954, -0.195330], [0.047954, 0.887470, 0.458370], [0.195330, -0.458370, 0.867030]]]), 'K': torch.tensor([[[529.5000, 0.0000, 365.0000], [0.0000, 529.5000, 265.0000], [0.0000, 0.0000, 1.0000]]]) } bboxes = torch.tensor([[[ 5.4286e+02, 9.8283e+01, 6.1700e+02, 1.6742e+02, 9.7922e-01, 3.0000e+00 ], [ 4.2613e+02, 8.4646e+01, 4.9091e+02, 1.6237e+02, 9.7848e-01, 3.0000e+00 ], [ 2.5606e+02, 7.3244e+01, 3.7883e+02, 1.8471e+02, 9.7317e-01, 3.0000e+00 ], [ 6.0104e+02, 1.0648e+02, 6.6757e+02, 1.9216e+02, 8.4607e-01, 3.0000e+00 ], [ 2.2923e+02, 1.4984e+02, 7.0163e+02, 4.6537e+02, 3.5719e-01, 0.0000e+00 ], [ 2.5614e+02, 7.4965e+01, 3.3275e+02, 1.5908e+02, 2.8688e-01, 3.0000e+00 ], [ 9.8718e+00, 1.4142e+02, 2.0213e+02, 3.3878e+02, 1.0935e-01, 3.0000e+00 ], [ 6.1930e+02, 1.1768e+02, 6.8505e+02, 2.0318e+02, 1.0720e-01, 3.0000e+00 ]]]) seeds_3d = torch.tensor([[[0.044544, 1.675476, -1.531831], [2.500625, 7.238662, -0.737675], [-0.600003, 4.827733, -0.084022], [1.396212, 3.994484, -1.551180], [-2.054746, 2.012759, -0.357472], [-0.582477, 6.580470, -1.466052], [1.313331, 5.722039, 0.123904], [-1.107057, 3.450359, -1.043422], [1.759746, 5.655951, -1.519564], [-0.203003, 6.453243, 0.137703], [-0.910429, 0.904407, -0.512307], [0.434049, 3.032374, -0.763842], [1.438146, 2.289263, -1.546332], [0.575622, 5.041906, -0.891143], [-1.675931, 1.417597, -1.588347]]]) imgs = torch.linspace( -1, 1, steps=608 * 832).reshape(1, 608, 832).repeat(3, 1, 1)[None] expected_tensor1 = torch.tensor( [[[ 0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00, 1.193706e-01, -0.000000e+00, -2.879214e-01, -0.000000e+00, 0.000000e+00, 1.422463e-01, -6.474612e-01, -0.000000e+00, 1.490057e-02, 0.000000e+00 ], [ 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00, -1.873745e+00, -0.000000e+00, 1.576240e-01, 0.000000e+00, -0.000000e+00, -3.646177e-02, -7.751858e-01, 0.000000e+00, 9.593642e-02, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, -6.263277e-02, 0.000000e+00, -3.646387e-01, 0.000000e+00, 0.000000e+00, -5.875812e-01, -6.263450e-02, 0.000000e+00, 1.149264e-01, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 8.899736e-01, 0.000000e+00, 9.019017e-01, 0.000000e+00, 0.000000e+00, 6.917775e-01, 8.899733e-01, 0.000000e+00, 9.812444e-01, 0.000000e+00 ], [ -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -4.516903e-01, -0.000000e+00, -2.315422e-01, -0.000000e+00, -0.000000e+00, -4.197519e-01, -4.516906e-01, -0.000000e+00, -1.547615e-01, -0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 3.571937e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 3.571937e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 9.731653e-01, 0.000000e+00, 0.000000e+00, 1.093455e-01, 0.000000e+00, 0.000000e+00, 8.460656e-01, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ]]]) expected_tensor2 = torch.tensor([[ False, False, False, False, False, True, False, True, False, False, True, True, False, True, False, False, False, False, False, False, False, False, True, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False ]]) expected_tensor3 = torch.tensor( [[[ -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, -0.000000e+00, 1.720988e-01, 0.000000e+00 ], [ 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 4.824460e-02, 0.000000e+00 ], [ -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, 1.447314e-01, -0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 9.759269e-01, 0.000000e+00 ], [ -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, -1.631542e-01, -0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.072001e-01, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ], [ 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, 2.540967e-03, -1.834944e-03, 1.032048e-03 ]]]) fusion = VoteFusion() out1, out2 = fusion(imgs, bboxes, seeds_3d, [img_meta], calibs) assert torch.allclose(expected_tensor1, out1[:, :, :15], 1e-3) assert torch.allclose(expected_tensor2.float(), out2.float(), 1e-3) assert torch.allclose(expected_tensor3, out1[:, :, 30:45], 1e-3) out1, out2 = fusion(imgs, bboxes[:, :2], seeds_3d, [img_meta], calibs) out1 = out1[:, :15, 30:45] out2 = out2[:, 30:45].float() assert torch.allclose(torch.zeros_like(out1), out1, 1e-3) assert torch.allclose(torch.zeros_like(out2), out2, 1e-3)