test_progressbar.py 4.99 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
import sys
2
import time
lizz's avatar
lizz committed
3
4
5
6
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
7

Kai Chen's avatar
Kai Chen committed
8
9
import pytest

10
11
import mmcv

lizz's avatar
lizz committed
12
13
14
15
16
17

def reset_string_io(io):
    """Empty an in-memory text buffer in place and rewind it.

    After this call ``io.getvalue()`` is ``''`` and the next write starts
    at position 0, so the same buffer can capture a fresh batch of output.

    Args:
        io: A ``StringIO`` buffer to clear.
    """
    # Rewind first; a bare truncate() then cuts everything past position 0.
    io.seek(0)
    io.truncate()


Kai Chen's avatar
Kai Chen committed
18
19
20
# Skip collection of this entire module when running under Python 2;
# allow_module_level=True is required for pytest.skip outside a test.
if sys.version_info < (3, ):
    pytest.skip('skipping tests for python 2', allow_module_level=True)

21
22
23

class TestProgressBar(object):
    """Tests for the frames written by ``mmcv.ProgressBar``."""

    def test_start(self):
        """Initial frame is drawn on construction, or deferred to start()."""
        stream = StringIO()
        width = 20
        idle_counter = 'completed: 0, elapsed: 0s'
        idle_bar = '[{}] 0/10, elapsed: 0s, ETA:'.format(' ' * width)

        # No total task count: only a completion counter is drawn.
        bar = mmcv.ProgressBar(bar_width=width, file=stream)
        assert stream.getvalue() == idle_counter
        reset_string_io(stream)
        # start=False writes nothing until start() is called explicitly.
        bar = mmcv.ProgressBar(bar_width=width, start=False, file=stream)
        assert stream.getvalue() == ''
        reset_string_io(stream)
        bar.start()
        assert stream.getvalue() == idle_counter

        # Total of 10 tasks: an empty bar and a 0/10 counter are drawn.
        reset_string_io(stream)
        bar = mmcv.ProgressBar(10, bar_width=width, file=stream)
        assert stream.getvalue() == idle_bar
        reset_string_io(stream)
        bar = mmcv.ProgressBar(10, bar_width=width, start=False, file=stream)
        assert stream.getvalue() == ''
        reset_string_io(stream)
        bar.start()
        assert stream.getvalue() == idle_bar

    def test_update(self):
        """One update() after a one-second pause redraws with timing info."""
        stream = StringIO()
        width = 20

        # No total task count: counter, elapsed time and throughput.
        bar = mmcv.ProgressBar(bar_width=width, file=stream)
        time.sleep(1)
        reset_string_io(stream)
        bar.update()
        assert stream.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
        reset_string_io(stream)

        # Total of 10 tasks: one update fills 2 of the 20 bar cells.
        bar = mmcv.ProgressBar(10, bar_width=width, file=stream)
        time.sleep(1)
        reset_string_io(stream)
        bar.update()
        filled = '>' * 2 + ' ' * (width - 2)
        expected = ('\r[' + filled + '] 1/10, 1.0 task/s, '
                    'elapsed: 1s, ETA:     9s')
        assert stream.getvalue() == expected
68
69
70
71
72
73
74


def sleep_1s(num):
    """Block for one second, then return ``num`` unchanged.

    Serves as a deliberately slow task for the progress-tracking tests.
    """
    delay_seconds = 1
    time.sleep(delay_seconds)
    result = num
    return result


lizz's avatar
lizz committed
75
76
77
78
79
80
81
82
def test_track_progress_list():
    """track_progress over a list redraws once per task and keeps order."""
    stream = StringIO()
    outputs = mmcv.track_progress(
        sleep_1s, [1, 2, 3], bar_width=3, file=stream)
    frames = (
        '[   ] 0/3, elapsed: 0s, ETA:',
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s',
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s',
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n',
    )
    assert stream.getvalue() == ''.join(frames)
    assert outputs == [1, 2, 3]


lizz's avatar
lizz committed
86
87
def test_track_progress_iterator():
    """track_progress accepts an (iterator, length) pair instead of a list."""
    stream = StringIO()
    # An iterator has no len(), so the task count is passed alongside it.
    tasks = (iter([1, 2, 3]), 3)
    outputs = mmcv.track_progress(sleep_1s, tasks, bar_width=3, file=stream)
    frames = (
        '[   ] 0/3, elapsed: 0s, ETA:',
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s',
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s',
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n',
    )
    assert stream.getvalue() == ''.join(frames)
    assert outputs == [1, 2, 3]


lizz's avatar
lizz committed
98
99
def test_track_iter_progress():
    """track_iter_progress yields every item while drawing the bar."""
    buf = StringIO()
    ret = [
        sleep_1s(num)
        for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=buf)
    ]
    frames = (
        '[   ] 0/3, elapsed: 0s, ETA:',
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s',
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s',
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n',
    )
    assert buf.getvalue() == ''.join(frames)
    assert ret == [1, 2, 3]


lizz's avatar
lizz committed
111
112
def test_track_enum_progress():
    """track_iter_progress still draws correctly when wrapped in enumerate."""
    buf = StringIO()
    tracked = mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=buf)
    pairs = [(idx, sleep_1s(num)) for idx, num in enumerate(tracked)]
    ret = [value for _, value in pairs]
    count = [idx for idx, _ in pairs]
    frames = (
        '[   ] 0/3, elapsed: 0s, ETA:',
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s',
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s',
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n',
    )
    assert buf.getvalue() == ''.join(frames)
    assert ret == [1, 2, 3]
    assert count == [0, 1, 2]


lizz's avatar
lizz committed
128
129
def test_track_parallel_progress_list():
    """track_parallel_progress keeps order and draws one frame per task.

    The throughput ("task/s") and elapsed-seconds figures depend on how
    fast the worker pool starts and schedules tasks, so asserting them
    verbatim makes this test flaky on loaded CI machines. Only the
    timing-independent parts of the output are checked here.
    """
    out = StringIO()
    results = mmcv.track_parallel_progress(
        sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
    output = out.getvalue()
    # The finished bar is terminated by a newline.
    assert output.endswith('\n')
    frames = output[:-1].split('\r')
    # The initial frame carries no timing data, so it is compared exactly.
    assert frames[0] == '[    ] 0/4, elapsed: 0s, ETA:'
    # One redraw per completed task, each filling one more bar cell.
    assert len(frames) == 5
    for done, frame in enumerate(frames[1:], 1):
        filled = '>' * done + ' ' * (4 - done)
        assert frame.startswith('[{}] {}/4, '.format(filled, done))
    # No work remains after the last task, so the ETA is always zero.
    assert frames[-1].endswith('ETA:     0s')
    assert results == [1, 2, 3, 4]


lizz's avatar
lizz committed
141
142
def test_track_parallel_progress_iterator():
    """Parallel tracking of an (iterator, length) pair keeps order.

    As in the list variant, throughput and elapsed-seconds figures depend on
    worker-pool startup and scheduling latency and would make an exact
    string comparison flaky, so only timing-independent output is checked.
    """
    out = StringIO()
    results = mmcv.track_parallel_progress(
        sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
    output = out.getvalue()
    # The finished bar is terminated by a newline.
    assert output.endswith('\n')
    frames = output[:-1].split('\r')
    # The initial frame carries no timing data, so it is compared exactly.
    assert frames[0] == '[    ] 0/4, elapsed: 0s, ETA:'
    # One redraw per completed task, each filling one more bar cell.
    assert len(frames) == 5
    for done, frame in enumerate(frames[1:], 1):
        filled = '>' * done + ' ' * (4 - done)
        assert frame.startswith('[{}] {}/4, '.format(filled, done))
    # No work remains after the last task, so the ETA is always zero.
    assert frames[-1].endswith('ETA:     0s')
    assert results == [1, 2, 3, 4]