test_image_misc.py 1.4 KB
Newer Older
1
2
3
4
5
6
7
# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import mmcv

Cao Yuhang's avatar
Cao Yuhang committed
8
9
10
11
try:
    import torch
except ImportError:
    torch = None
12

Cao Yuhang's avatar
Cao Yuhang committed
13
14

def _chw_batch_to_hwc_uint8(batch):
    """Build reference images: per-image CHW -> HWC transpose, cast to uint8."""
    return [
        t.cpu().numpy().transpose(1, 2, 0).astype(np.uint8) for t in batch
    ]


@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_tensor2imgs():
    """Check input validation and outputs of ``mmcv.tensor2imgs``.

    Bad inputs (a non-tensor, a tensor with the wrong ndim, ``mean`` or
    ``std`` of the wrong length) must raise ``AssertionError``. For valid
    input, the output for each ``to_rgb`` setting is compared against a
    reference built directly from the input tensor. The reference transposes
    each image from CHW to HWC and casts it to ``uint8``. For
    ``to_rgb=True``, the channel axis is reversed first.
    """

    # test tensor obj: a numpy array is rejected
    array = np.random.rand(2, 3, 3)
    with pytest.raises(AssertionError):
        mmcv.tensor2imgs(array)

    # test tensor ndim: a 3-D tensor is rejected
    tensor = torch.randn(2, 3, 3)
    with pytest.raises(AssertionError):
        mmcv.tensor2imgs(tensor)

    # test mean length
    tensor = torch.randn(2, 3, 5, 5)
    with pytest.raises(AssertionError):
        mmcv.tensor2imgs(tensor, mean=(1, ))

    # test std length
    with pytest.raises(AssertionError):
        mmcv.tensor2imgs(tensor, std=(1, ))

    # For the value comparisons, draw inputs in [0, 255). Casting negative or
    # >255 floats to uint8 is undefined in NumPy/C and can differ between
    # platforms, which would make the comparison depend on the machine.

    # test to_rgb=True: output channels are reversed relative to the input
    tensor = torch.rand(2, 3, 5, 5) * 255
    gts = _chw_batch_to_hwc_uint8(tensor.flip(1))
    outputs = mmcv.tensor2imgs(tensor, to_rgb=True)
    # zip() stops at the shorter sequence, so check the count explicitly;
    # otherwise a short or empty result would pass vacuously.
    assert len(outputs) == len(gts)
    for gt, output in zip(gts, outputs):
        assert_array_equal(gt, output)

    # test to_rgb=False: channel order is unchanged
    tensor = torch.rand(2, 3, 5, 5) * 255
    gts = _chw_batch_to_hwc_uint8(tensor)
    outputs = mmcv.tensor2imgs(tensor, to_rgb=False)
    assert len(outputs) == len(gts)
    for gt, output in zip(gts, outputs):
        assert_array_equal(gt, output)