# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet3d.registry import MODELS


def test_multibin_loss():
    """Check ``MultiBinLoss`` argument validation and a reference loss value.

    The test has two parts:

    1. Constructing ``MultiBinLoss`` with an unsupported ``reduction``
       (``'l2'``) must raise ``AssertionError``.
    2. A loss built through the ``MODELS`` registry with
       ``reduction='none'`` and ``loss_weight=1.0`` is applied to a fixed
       prediction/target pair with ``num_dir_bins=4`` and must match the
       reference value ``2.1120`` within an absolute tolerance of ``1e-4``.
    """
    # Imported inside the test so the name is available for the
    # constructor check below.
    from mmdet3d.models.losses import MultiBinLoss

    # reduction should be in ['none', 'mean', 'sum']
    with pytest.raises(AssertionError):
        MultiBinLoss(reduction='l2')

    # Two samples, 16 predicted values per sample.
    pred = torch.tensor([
        [
            0.81, 0.32, 0.78, 0.52, 0.24, 0.12, 0.32, 0.11,
            1.20, 1.30, 0.20, 0.11, 0.12, 0.11, 0.23, 0.31,
        ],
        [
            0.02, 0.19, 0.78, 0.22, 0.31, 0.12, 0.22, 0.11,
            1.20, 1.30, 0.45, 0.51, 0.12, 0.11, 0.13, 0.61,
        ],
    ])
    # Two samples, 8 target values per sample.
    # NOTE(review): presumably 4 per-bin labels followed by 4 angle
    # values (matching num_dir_bins=4) -- confirm against MultiBinLoss.
    target = torch.tensor([
        [1, 1, 0, 0, 2.14, 3.12, 0.68, -2.15],
        [1, 1, 0, 0, 3.12, 3.12, 2.34, 1.23],
    ])

    # Build the loss through the registry rather than calling the class
    # directly, so the registered 'MultiBinLoss' type name is exercised too.
    multibin_loss_cfg = dict(
        type='MultiBinLoss', reduction='none', loss_weight=1.0)
    multibin_loss = MODELS.build(multibin_loss_cfg)

    output_multibin_loss = multibin_loss(pred, target, num_dir_bins=4)
    expected_multibin_loss = torch.tensor(2.1120)
    assert torch.allclose(
        output_multibin_loss, expected_multibin_loss, atol=1e-4)