test_misc.py 2.86 KB
Newer Older
1
2
import pytest

Kai Chen's avatar
Kai Chen committed
3
4
import mmcv

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

def test_iter_cast():
    """Exercise ``list_cast``, ``tuple_cast`` and ``iter_cast``.

    Valid inputs must come back with every element converted to the target
    type; an invalid target type or a non-iterable input must raise
    TypeError.
    """
    conversions = [
        (mmcv.list_cast, [1, 2, 3], int, [1, 2, 3]),
        (mmcv.list_cast, ['1.1', 2, '3'], float, [1.1, 2.0, 3.0]),
        (mmcv.list_cast, [1, 2, 3], str, ['1', '2', '3']),
        (mmcv.tuple_cast, (1, 2, 3), str, ('1', '2', '3')),
    ]
    for caster, values, target, want in conversions:
        assert caster(values, target) == want

    # iter_cast hands back an iterator rather than a container, so only
    # its first converted item is inspected.
    first = next(mmcv.iter_cast([1, 2, 3], str))
    assert first == '1'

    # '' is not a usable target type, and a bare int is not iterable.
    for values, target in (([1, 2, 3], ''), (1, str)):
        with pytest.raises(TypeError):
            mmcv.iter_cast(values, target)


def test_is_seq_of():
    """Exercise the sequence-type predicates on matching and mismatching data.

    ``is_seq_of`` accepts any sequence by default or a specific container via
    ``seq_type``; ``is_list_of`` / ``is_tuple_of`` pin the container type.
    """
    # (predicate, sequence, element type, extra kwargs, expected truthiness)
    checks = [
        (mmcv.is_seq_of, [1.0, 2.0, 3.0], float, {}, True),
        (mmcv.is_seq_of, [(1, ), (2, ), (3, )], tuple, {}, True),
        (mmcv.is_seq_of, (1.0, 2.0, 3.0), float, {}, True),
        (mmcv.is_list_of, [1.0, 2.0, 3.0], float, {}, True),
        # Container type mismatches: a tuple is not a list and vice versa.
        (mmcv.is_seq_of, (1.0, 2.0, 3.0), float, {'seq_type': list}, False),
        (mmcv.is_tuple_of, [1.0, 2.0, 3.0], float, {}, False),
        # One float among ints is enough to fail the element-type check.
        (mmcv.is_seq_of, [1.0, 2, 3], int, {}, False),
        (mmcv.is_seq_of, (1.0, 2, 3), int, {}, False),
    ]
    for predicate, seq, expected_type, kwargs, expected in checks:
        assert bool(predicate(seq, expected_type, **kwargs)) is expected


def test_slice_list():
    """Exercise splitting a list into consecutive chunks of given lengths."""
    source = [1, 2, 3, 4, 5, 6]

    chunks = mmcv.slice_list(source, [1, 2, 3])
    assert chunks == [[1], [2, 3], [4, 5, 6]]

    # A single length spanning the whole list yields one chunk equal to it.
    whole = [len(source)]
    assert mmcv.slice_list(source, whole) == [source]

    # A float length spec raises TypeError; lengths [1, 2] (sum 3) do not
    # cover the 6-element list and raise ValueError.
    for bad_lens, error in ((2.0, TypeError), ([1, 2], ValueError)):
        with pytest.raises(error):
            mmcv.slice_list(source, bad_lens)


def test_concat_list():
    """Exercise flattening a list of lists into a single list, in order."""
    for nested, flat in [
        ([[1, 2]], [1, 2]),
        ([[1, 2], [3, 4, 5], [6]], [1, 2, 3, 4, 5, 6]),
    ]:
        assert mmcv.concat_list(nested) == flat


def test_requires_package(capsys):
    """Exercise the ``requires_package`` decorator.

    A decorated function with missing prerequisite packages must raise
    RuntimeError and print the missing names (only those missing) to stdout;
    a function whose prerequisites are all installed must run normally.
    """

    @mmcv.requires_package('nnn')
    def func_a():
        pass

    @mmcv.requires_package(['numpy', 'n1', 'n2'])
    def func_b():
        pass

    # numpy is already required to be installed by this test (func_b's
    # expected message below does not list it as missing), so reuse it for
    # the "all prerequisites present" case instead of depending on an
    # unrelated third-party package such as ``six``.
    @mmcv.requires_package('numpy')
    def func_c():
        return 1

    with pytest.raises(RuntimeError):
        func_a()
    out, _ = capsys.readouterr()
    assert out == ('Prerequisites "nnn" are required in method "func_a" but '
                   'not found, please install them first.\n')

    with pytest.raises(RuntimeError):
        func_b()
    out, _ = capsys.readouterr()
    # Only the missing packages are reported; the installed numpy is omitted.
    assert out == (
        'Prerequisites "n1, n2" are required in method "func_b" but not found,'
        ' please install them first.\n')

    assert func_c() == 1


def test_requires_executable(capsys):
    """Exercise the ``requires_executable`` decorator.

    Functions guarded by missing executables must raise RuntimeError and
    print which executables are missing (only those missing) to stdout; a
    function whose executables are all found must run normally.
    """

    @mmcv.requires_executable('nnn')
    def func_a():
        pass

    @mmcv.requires_executable(['ls', 'n1', 'n2'])
    def func_b():
        pass

    @mmcv.requires_executable('mv')
    def func_c():
        return 1

    # (guarded function, exact stdout message it must produce)
    failing_cases = (
        (func_a,
         'Prerequisites "nnn" are required in method "func_a" but '
         'not found, please install them first.\n'),
        (func_b,
         'Prerequisites "n1, n2" are required in method "func_b" but '
         'not found, please install them first.\n'),
    )
    for guarded, expected_message in failing_cases:
        with pytest.raises(RuntimeError):
            guarded()
        printed, _ = capsys.readouterr()
        assert printed == expected_message

    assert func_c() == 1