Commit 25b422ab authored by lizz's avatar lizz Committed by Kai Chen
Browse files

Let progress goes to stderr (#115)



* let progress goes to stderr
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Add argument file=sys.stdout
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* For the sake of the stupid pytest
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* line width < 120 is anti-humane
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* flake8
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Lets try
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent d5865e0c
......@@ -8,12 +8,13 @@ from .timer import Timer
class ProgressBar(object):
"""A progress bar which can print the progress"""
def __init__(self, task_num=0, bar_width=50, start=True):
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num
max_bar_width = self._get_max_bar_width()
self.bar_width = (
bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0
self.file = file
if start:
self.start()
......@@ -33,11 +34,11 @@ class ProgressBar(object):
def start(self):
if self.task_num > 0:
sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:'.format(
self.file.write('[{}] 0/{}, elapsed: 0s, ETA:'.format(
' ' * self.bar_width, self.task_num))
else:
sys.stdout.write('completed: 0, elapsed: 0s')
sys.stdout.flush()
self.file.write('completed: 0, elapsed: 0s')
self.file.flush()
self.timer = Timer()
def update(self):
......@@ -49,18 +50,18 @@ class ProgressBar(object):
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
mark_width = int(self.bar_width * percentage)
bar_chars = '>' * mark_width + ' ' * (self.bar_width - mark_width)
sys.stdout.write(
self.file.write(
'\r[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s'.format(
bar_chars, self.completed, self.task_num, fps,
int(elapsed + 0.5), eta))
else:
sys.stdout.write(
self.file.write(
'completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
self.completed, int(elapsed + 0.5), fps))
sys.stdout.flush()
self.file.flush()
def track_progress(func, tasks, bar_width=50, **kwargs):
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
"""Track the progress of tasks execution with a progress bar.
Tasks are done with a simple for-loop.
......@@ -85,12 +86,12 @@ def track_progress(func, tasks, bar_width=50, **kwargs):
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width)
prog_bar = ProgressBar(task_num, bar_width, file=file)
results = []
for task in tasks:
results.append(func(task, **kwargs))
prog_bar.update()
sys.stdout.write('\n')
prog_bar.file.write('\n')
return results
......@@ -113,7 +114,8 @@ def track_parallel_progress(func,
bar_width=50,
chunksize=1,
skip_first=False,
keep_order=True):
keep_order=True,
file=sys.stdout):
"""Track the progress of parallel task execution with a progress bar.
The built-in :mod:`multiprocessing` module is used for process pools and
......@@ -153,7 +155,7 @@ def track_parallel_progress(func,
pool = init_pool(nproc, initializer, initargs)
start = not skip_first
task_num -= nproc * chunksize * int(skip_first)
prog_bar = ProgressBar(task_num, bar_width, start)
prog_bar = ProgressBar(task_num, bar_width, start, file=file)
results = []
if keep_order:
gen = pool.imap(func, tasks, chunksize)
......@@ -168,13 +170,13 @@ def track_parallel_progress(func,
prog_bar.start()
continue
prog_bar.update()
sys.stdout.write('\n')
prog_bar.file.write('\n')
pool.close()
pool.join()
return results
def track_iter_progress(tasks, bar_width=50, **kwargs):
def track_iter_progress(tasks, bar_width=50, file=sys.stdout, **kwargs):
"""Track the progress of tasks iteration or enumeration with a progress bar.
Tasks are yielded with a simple for-loop.
......@@ -198,8 +200,8 @@ def track_iter_progress(tasks, bar_width=50, **kwargs):
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width)
prog_bar = ProgressBar(task_num, bar_width, file=file)
for task in tasks:
yield task
prog_bar.update()
sys.stdout.write('\n')
prog_bar.file.write('\n')
import sys
import time
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import pytest
import mmcv
def reset_string_io(io):
io.truncate(0)
io.seek(0)
if sys.version_info[0] == 2:
pytest.skip('skipping tests for python 2', allow_module_level=True)
class TestProgressBar(object):
def test_start(self, capsys):
def test_start(self):
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width)
out, _ = capsys.readouterr()
assert out == 'completed: 0, elapsed: 0s'
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False)
out, _ = capsys.readouterr()
assert out == ''
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
assert out.getvalue() == 'completed: 0, elapsed: 0s'
reset_string_io(out)
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
prog_bar.start()
out, _ = capsys.readouterr()
assert out == 'completed: 0, elapsed: 0s'
assert out.getvalue() == 'completed: 0, elapsed: 0s'
# with total task num
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width)
out, _ = capsys.readouterr()
assert out == '[{}] 0/10, elapsed: 0s, ETA:'.format(' ' * bar_width)
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, start=False)
out, _ = capsys.readouterr()
assert out == ''
reset_string_io(out)
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
assert out.getvalue() == '[{}] 0/10, elapsed: 0s, ETA:'.format(
' ' * bar_width)
reset_string_io(out)
prog_bar = mmcv.ProgressBar(
10, bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
prog_bar.start()
out, _ = capsys.readouterr()
assert out == '[{}] 0/10, elapsed: 0s, ETA:'.format(' ' * bar_width)
assert out.getvalue() == '[{}] 0/10, elapsed: 0s, ETA:'.format(
' ' * bar_width)
def test_update(self, capsys):
def test_update(self):
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width)
capsys.readouterr()
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
out, _ = capsys.readouterr()
assert out == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
reset_string_io(out)
# with total task num
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width)
capsys.readouterr()
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
out, _ = capsys.readouterr()
assert out == ('\r[{}] 1/10, 1.0 task/s, elapsed: 1s, ETA: 9s'.
format('>' * 2 + ' ' * 18))
assert out.getvalue() == ('\r[{}] 1/10, 1.0 task/s, '
'elapsed: 1s, ETA: 9s'.format('>' * 2 +
' ' * 18))
def sleep_1s(num):
......@@ -58,77 +72,80 @@ def sleep_1s(num):
return num
def test_track_progress_list(capsys):
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
def test_track_progress_list():
out = StringIO()
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_progress_iterator(capsys):
def test_track_progress_iterator():
out = StringIO()
ret = mmcv.track_progress(
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_iter_progress(capsys):
def test_track_iter_progress():
out = StringIO()
ret = []
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3):
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
ret.append(sleep_1s(num))
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_enum_progress(capsys):
def test_track_enum_progress():
out = StringIO()
ret = []
count = []
for i, num in enumerate(mmcv.track_iter_progress([1, 2, 3], bar_width=3)):
for i, num in enumerate(
mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
ret.append(sleep_1s(num))
count.append(i)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
assert count == [0, 1, 2]
def test_track_parallel_progress_list(capsys):
def test_track_parallel_progress_list():
out = StringIO()
results = mmcv.track_parallel_progress(
sleep_1s, [1, 2, 3, 4], 2, bar_width=4)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/4, elapsed: 0s, ETA:'
'\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
'\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
'\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
assert out.getvalue() == (
'[ ] 0/4, elapsed: 0s, ETA:'
'\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
'\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
'\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
def test_track_parallel_progress_iterator(capsys):
def test_track_parallel_progress_iterator():
out = StringIO()
results = mmcv.track_parallel_progress(
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/4, elapsed: 0s, ETA:'
'\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
'\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
'\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
assert out.getvalue() == (
'[ ] 0/4, elapsed: 0s, ETA:'
'\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
'\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
'\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment