Unverified Commit 712651aa authored by lizz's avatar lizz Committed by GitHub
Browse files

Adaptive progress bar length (#174)



* Adaptive progress bar length
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* it works
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* format

* pass test

* :lipstick

* test

* Update test_progressbar.py

* 2.7

* sort import
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* try this
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent 90a0f353
...@@ -4,10 +4,10 @@ repos: ...@@ -4,10 +4,10 @@ repos:
rev: v1.9.3 rev: v1.9.3
hooks: hooks:
- id: seed-isort-config - id: seed-isort-config
# - repo: https://github.com/pre-commit/mirrors-isort - repo: https://github.com/pre-commit/mirrors-isort
# rev: v4.3.21 rev: v4.3.21
# hooks: hooks:
# - id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.29.0 rev: v0.29.0
hooks: hooks:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# #
import os import os
import sys import sys
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
version_file = '../mmcv/version.py' version_file = '../mmcv/version.py'
......
...@@ -11,27 +11,20 @@ class ProgressBar(object): ...@@ -11,27 +11,20 @@ class ProgressBar(object):
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num self.task_num = task_num
max_bar_width = self._get_max_bar_width() self.bar_width = bar_width
self.bar_width = (
bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0 self.completed = 0
self.file = file self.file = file
if start: if start:
self.start() self.start()
def _get_max_bar_width(self): @property
def terminal_width(self):
if sys.version_info > (3, 3): if sys.version_info > (3, 3):
from shutil import get_terminal_size from shutil import get_terminal_size
else: else:
from backports.shutil_get_terminal_size import get_terminal_size from backports.shutil_get_terminal_size import get_terminal_size
terminal_width, _ = get_terminal_size() width, _ = get_terminal_size()
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) return width
if max_bar_width < 10:
print('terminal width is too small ({}), please consider '
'widen the terminal for better progressbar '
'visualization'.format(terminal_width))
max_bar_width = 10
return max_bar_width
def start(self): def start(self):
if self.task_num > 0: if self.task_num > 0:
...@@ -52,12 +45,17 @@ class ProgressBar(object): ...@@ -52,12 +45,17 @@ class ProgressBar(object):
if self.task_num > 0: if self.task_num > 0:
percentage = self.completed / float(self.task_num) percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5) eta = int(elapsed * (1 - percentage) / percentage + 0.5)
mark_width = int(self.bar_width * percentage) msg = '\r[{{}}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s' \
bar_chars = '>' * mark_width + ' ' * (self.bar_width - mark_width) ''.format(self.completed, self.task_num, fps,
self.file.write( int(elapsed + 0.5), eta)
'\r[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s'.format(
bar_chars, self.completed, self.task_num, fps, bar_width = min(self.bar_width,
int(elapsed + 0.5), eta)) int(self.terminal_width - len(msg)) + 2,
int(self.terminal_width * 0.6))
bar_width = max(2, bar_width)
mark_width = int(bar_width * percentage)
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
self.file.write(msg.format(bar_chars))
else: else:
self.file.write( self.file.write(
'completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 'completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
......
...@@ -2,14 +2,13 @@ import platform ...@@ -2,14 +2,13 @@ import platform
import re import re
import sys import sys
from io import open # for Python 2 (identical to builtin in Python 3) from io import open # for Python 2 (identical to builtin in Python 3)
from setuptools import Extension, dist, find_packages, setup
from pkg_resources import DistributionNotFound, get_distribution from pkg_resources import DistributionNotFound, get_distribution
from setuptools import Extension, dist, find_packages, setup
dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1']) dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
import numpy # noqa: E402 import numpy # NOQA: E402 # isort:skip
from Cython.Distutils import build_ext # noqa: E402 from Cython.Distutils import build_ext # NOQA: E402 # isort:skip
def choose_requirement(primary, secondary): def choose_requirement(primary, secondary):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import os
import sys import sys
import time import time
try:
from unittest.mock import patch
except ImportError:
from mock import patch
try: try:
from StringIO import StringIO from StringIO import StringIO
except ImportError: except ImportError:
...@@ -68,6 +74,26 @@ class TestProgressBar(object): ...@@ -68,6 +74,26 @@ class TestProgressBar(object):
'elapsed: 1s, ETA: 9s'.format('>' * 2 + 'elapsed: 1s, ETA: 9s'.format('>' * 2 +
' ' * 18)) ' ' * 18))
def test_adaptive_length(self):
with patch.dict('os.environ', {'COLUMNS': '80'}):
out = StringIO()
bar_width = 20
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 66
os.environ['COLUMNS'] = '30'
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 48
os.environ['COLUMNS'] = '60'
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 60
def sleep_1s(num): def sleep_1s(num):
time.sleep(1) time.sleep(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment