test_utils.py 321 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
import torch

from mmdet3d.core import draw_heatmap_gaussian


def test_gaussian():
    """Check the summed response of a Gaussian drawn on an empty heatmap.

    A zero-filled 128x128 heatmap is passed to ``draw_heatmap_gaussian``
    together with an int32 center of (64, 64) and a radius of 2. The
    return value of the call is ignored: the test relies on the heatmap
    being updated in place, and then compares the sum of all heatmap
    entries with the reference value 4.3505 (absolute tolerance 1e-3).
    """
    canvas = torch.zeros((128, 128))
    center = torch.tensor([64, 64], dtype=torch.int32)

    # Called for its effect on ``canvas``; nothing is read from the result.
    draw_heatmap_gaussian(canvas, center, 2)

    expected_total = torch.tensor(4.3505)
    assert torch.isclose(canvas.sum(), expected_total, atol=1e-3)