# Copyright (c) OpenMMLab. All rights reserved.
import os
import time
from io import StringIO
from unittest.mock import patch

import mmcv


def reset_string_io(io):
    """Empty ``io`` in place and rewind it to the start.

    After this call ``io.getvalue()`` only reflects text written afterwards,
    which lets a test inspect the output of one progress-bar step at a time.

    Args:
        io: A seekable in-memory stream such as ``StringIO``.
    """
    io.seek(0)
    io.truncate()


class TestProgressBar:
    """Checks the text ``mmcv.ProgressBar`` writes to an in-memory stream."""

    def test_start(self):
        """The initial status line appears on creation, or on ``start()``
        when the bar was created with ``start=False``."""
        out = StringIO()
        bar_width = 20
        # without total task num
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
        assert out.getvalue() == 'completed: 0, elapsed: 0s'
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
        # start=False: nothing is written until start() is called.
        assert out.getvalue() == ''
        reset_string_io(out)
        prog_bar.start()
        assert out.getvalue() == 'completed: 0, elapsed: 0s'
        # with total task num
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
        assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(
            10, bar_width=bar_width, start=False, file=out)
        assert out.getvalue() == ''
        reset_string_io(out)
        prog_bar.start()
        assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'

    def test_update(self):
        """One ``update()`` after a 1s pause reports one task at 1.0/s."""
        # NOTE(review): relies on wall-clock sleeps, so the exact elapsed and
        # rate strings may be flaky on heavily loaded machines.
        out = StringIO()
        bar_width = 20
        # without total task num
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
        time.sleep(1)
        reset_string_io(out)
        prog_bar.update()
        assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
        reset_string_io(out)
        # with total task num
        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
        time.sleep(1)
        reset_string_io(out)
        prog_bar.update()
        # 1 of 10 tasks done -> 2 of the 20 bar cells are filled.
        assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \
                                 'task/s, elapsed: 1s, ETA:     9s'

    def test_adaptive_length(self):
        """The rendered line length changes with the ``COLUMNS`` width."""
        # patch.dict restores os.environ on exit, so the direct COLUMNS
        # assignments below do not leak into other tests.
        with patch.dict('os.environ', {'COLUMNS': '80'}):
            out = StringIO()
            bar_width = 20
            prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
            time.sleep(1)
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 66

            os.environ['COLUMNS'] = '30'
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 48

            os.environ['COLUMNS'] = '60'
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 60


def sleep_1s(num):
    """Hand ``num`` back unchanged after blocking for one second.

    This is a deliberately slow task. The fixed one-second cost is what makes
    the rate and elapsed-time text in the expected progress output
    predictable (e.g. ``1.0 task/s``).
    """
    pause_seconds = 1
    time.sleep(pause_seconds)
    return num


def test_track_progress_list():
    """``track_progress`` over a list prints one refresh per task, in order."""
    out = StringIO()
    ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]


def test_track_progress_iterator():
    """``track_progress`` accepts tasks given as an ``(iterator, total)``
    pair and prints the same transcript as the list form."""
    out = StringIO()
    ret = mmcv.track_progress(
        sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]


def test_track_iter_progress():
    """``track_iter_progress`` yields items unchanged while drawing the bar.

    The caller does the work (``sleep_1s``) inside the loop body.
    """
    out = StringIO()
    ret = []
    for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
        ret.append(sleep_1s(num))
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]


def test_track_enum_progress():
    """Wrapping ``track_iter_progress`` in ``enumerate`` keeps both the
    indices and the progress transcript intact."""
    out = StringIO()
    ret = []
    count = []
    for i, num in enumerate(
            mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
        ret.append(sleep_1s(num))
        count.append(i)
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]
    assert count == [0, 1, 2]


def test_track_parallel_progress_list():
    """``track_parallel_progress`` over a list returns results in task order
    and draws a complete bar."""
    out = StringIO()
    results = mmcv.track_parallel_progress(
        sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
    # The exact rates and elapsed times depend on worker scheduling and are
    # not stable on CI runners (e.g. GitHub Actions). Only the
    # timing-independent parts of the transcript are checked:
    #   - the initial line,
    #   - one '\r'-prefixed refresh per task,
    #   - a full bar at 4/4,
    #   - the trailing newline.
    output = out.getvalue()
    assert output.startswith('[    ] 0/4, elapsed: 0s, ETA:')
    assert output.count('\r') == 4
    assert '\r[>>>>] 4/4, ' in output
    assert output.endswith('\n')
    assert results == [1, 2, 3, 4]


def test_track_parallel_progress_iterator():
    """``track_parallel_progress`` accepts an ``(iterator, total)`` pair,
    returns results in task order and draws a complete bar."""
    out = StringIO()
    results = mmcv.track_parallel_progress(
        sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
    # The exact rates and elapsed times depend on worker scheduling and are
    # not stable on CI runners (e.g. GitHub Actions). Only the
    # timing-independent parts of the transcript are checked:
    #   - the initial line,
    #   - one '\r'-prefixed refresh per task,
    #   - a full bar at 4/4,
    #   - the trailing newline.
    output = out.getvalue()
    assert output.startswith('[    ] 0/4, elapsed: 0s, ETA:')
    assert output.count('\r') == 4
    assert '\r[>>>>] 4/4, ' in output
    assert output.endswith('\n')
    assert results == [1, 2, 3, 4]