import numpy as np

from mmdet3d.core.evaluation.indoor_eval import indoor_eval


def test_indoor_eval():
    """Check ``indoor_eval`` AP/recall outputs on two small synthetic scenes.

    Each scene supplies five detections and twelve ground-truth boxes.
    ``indoor_eval`` is evaluated at IoU thresholds 0.25 and 0.5. Selected
    per-class AP/recall values, plus mAP/mAR, are then compared against
    fixed expected numbers: exactly for some, within 1e-3 for others.
    """
    # Each detection entry is [label, box, score]. The box has 7 values;
    # presumably (x, y, z, dx, dy, dz, yaw) -- confirm against indoor_eval.
    # Labels are floats that appear to index into ``label2cat`` below.
    scene0_dets = [
        [4.0, [2.8734498, -0.187645, -0.02600911,
               0.6761766, 0.56542563, 0.5953976, 0.], 0.9980684],
        [4.0, [0.4031701, -3.2346897, 0.07118589,
               0.73209894, 0.8711227, 0.5148243, 0.], 0.9747082],
        [3.0, [-1.274147, -2.351935, 0.07428858,
               1.4534658, 2.563081, 0.8587492, 0.], 0.9709939],
        [17.0, [3.2214177, 0.7899204, 0.03836718,
                0.05321002, 1.2607929, 0.1411697, 0.], 0.9482147],
        [2.0, [-1.6804854, 2.399011, -0.13099639,
               0.5608963, 0.5052759, 0.6770297, 0.], 0.84311247],
    ]
    scene1_dets = [
        [17.0, [3.2112048e+00, 5.6918913e-01, -8.6143613e-04,
                1.1942449e-01, 1.2988183e+00, 1.9952521e-01,
                0.0000000e+00], 0.9965866],
        [17.0, [3.248133, 0.4324184, 0.20038621,
                0.17225507, 1.2736976, 0.32598814, 0.], 0.99507546],
        [3.0, [-1.2793612, -2.3155289, 0.15598366,
               1.2822601, 2.2253945, 0.8361754, 0.], 0.9916463],
        [4.0, [2.8716104, -0.26416883, -0.04933786,
               0.8190681, 0.60294986, 0.5769499, 0.], 0.9702634],
        [17.0, [-2.2109854, 0.19445783, -0.01614259,
                0.40659013, 0.35370222, 0.3290567, 0.], 0.95803124],
    ]
    # Each scene's detections are wrapped in one extra list level, matching
    # the structure indoor_eval is called with here.
    det_infos = [[scene0_dets], [scene1_dets]]

    # Class names in label order: position i becomes label2cat[i].
    class_names = (
        'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
        'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
        'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin',
    )
    label2cat = dict(enumerate(class_names))

    # Ground-truth boxes: 12 rows of 6 values per scene. Each row is paired
    # by position with the matching entry of that scene's 'class' list.
    scene0_gt_boxes = np.array([
        [2.54621506, -0.89397144, 0.54144311,
         2.90430856, 1.78370309, 0.93826824],
        [3.36553669, 0.31014189, 0.38758934,
         1.2504847, 0.71281439, 0.3908577],
        [0.17272574, 2.90289116, 0.27966365,
         0.56292468, 0.8512187, 0.4987641],
        [2.39521956, 1.67557895, 0.40407273,
         1.23511314, 0.49469376, 0.62720448],
        [-2.41815996, -1.69104958, 0.22304082,
         0.55816364, 0.48154473, 0.66580439],
        [-0.18044823, 2.9227581, 0.24480903,
         0.36165208, 0.44468427, 0.53103662],
        [-2.44398379, -2.1610918, 0.23631772,
         0.52229881, 0.63388562, 0.66596919],
        [-2.01452827, -2.9558928, 0.8139953,
         1.61732554, 0.60224247, 1.79295814],
        [-0.61519569, 3.24365234, 1.24335742,
         2.11988783, 0.26006722, 1.77748263],
        [-2.64330673, 0.59929442, 1.59422684,
         0.07352924, 0.28620502, 0.35408139],
        [-0.58128822, 3.23699641, 0.06050609,
         1.94151425, 0.16413498, 0.20168215],
        [0.15343043, 2.24693251, 0.22470728,
         0.49632657, 0.47379827, 0.43063563],
    ])
    scene1_gt_boxes = np.array([
        [3.48649406, 0.24238291, 0.48358256,
         1.34014034, 0.72744983, 0.40819243],
        [-0.50371504, 3.25293231, 1.25988698,
         2.12330937, 0.27563906, 1.80230701],
        [2.58820581, -0.99452347, 0.57732373,
         2.94801593, 1.67463434, 0.88743341],
        [-1.9116497, -2.88811016, 0.70502496,
         1.62386703, 0.60732293, 1.5857985],
        [-2.55324745, 0.6909315, 1.59045517,
         0.07264495, 0.32018459, 0.3506999],
        [-2.3436017, -2.1659112, 0.254318,
         0.5333302, 0.56154585, 0.64904487],
        [-2.32046795, -1.6880455, 0.26138437,
         0.5586133, 0.59743834, 0.6378752],
        [-0.46495372, 3.22126102, 0.03188983,
         1.92557108, 0.15160203, 0.24680007],
        [0.28087699, 2.88433838, 0.2495866,
         0.57001019, 0.85177159, 0.5689255],
        [-0.05292395, 2.90586925, 0.23064148,
         0.39113954, 0.43746281, 0.52981442],
        [0.25537968, 2.25156307, 0.24932587,
         0.48192862, 0.51398182, 0.38040417],
        [2.60432816, 1.62303996, 0.42025632,
         1.23775268, 0.51761389, 0.66034317],
    ])
    gt_annos = [
        {
            'gt_num': 12,
            'gt_boxes_upright_depth': scene0_gt_boxes,
            'class': [3, 4, 4, 17, 2, 2, 2, 7, 11, 8, 17, 2],
        },
        {
            'gt_num': 12,
            'gt_boxes_upright_depth': scene1_gt_boxes,
            'class': [4, 11, 3, 7, 8, 2, 2, 17, 4, 2, 2, 17],
        },
    ]

    ret_value = indoor_eval(gt_annos, det_infos, [0.25, 0.5], label2cat)

    # (metric key, expected value, tolerance). A tolerance of None means the
    # value must compare exactly equal; otherwise abs(diff) < tolerance.
    expectations = [
        ('garbagebin_AP_0.25', 0.25, None),
        ('sofa_AP_0.25', 1.0, None),
        ('table_AP_0.25', 0.75, None),
        ('chair_AP_0.25', 0.125, None),
        ('mAP_0.25', 0.303571, 0.001),
        ('garbagebin_rec_0.25', 0.25, None),
        ('sofa_rec_0.25', 1.0, None),
        ('table_rec_0.25', 0.75, None),
        ('chair_rec_0.25', 0.125, None),
        ('mAR_0.25', 0.303571, 0.001),
        ('sofa_AP_0.50', 0.25, None),
        ('table_AP_0.50', 0.416667, 0.001),
        ('chair_AP_0.50', 0.125, None),
        ('mAP_0.50', 0.113095, 0.001),
        ('sofa_rec_0.50', 0.5, None),
        ('table_rec_0.50', 0.5, None),
        ('chair_rec_0.50', 0.125, None),
        ('mAR_0.50', 0.160714, 0.001),
    ]
    # Look every metric up first, so a missing key surfaces as a KeyError
    # before any value comparison runs.
    observed = [ret_value[key] for key, _, _ in expectations]
    for value, (_, expected, tolerance) in zip(observed, expectations):
        if tolerance is None:
            assert value == expected
        else:
            assert abs(value - expected) < tolerance
