Unverified Commit e8489a7b authored by Ezra-Yu's avatar Ezra-Yu Committed by GitHub
Browse files

Add case case_sensitive in scandir (#1389)

* add case_insensitive

* rename v

* case_insensitive to case_sensitive

* Update docstring
parent c85c240f
......@@ -36,7 +36,7 @@ def symlink(src, dst, overwrite=True, **kwargs):
os.symlink(src, dst, **kwargs)
def scandir(dir_path, suffix=None, recursive=False):
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files.
Args:
......@@ -45,6 +45,8 @@ def scandir(dir_path, suffix=None, recursive=False):
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
case_sensitive (bool, optional) : If set to False, ignore the case of
suffix. Default: True.
Returns:
A generator for all the interested files with relative paths.
......@@ -57,20 +59,25 @@ def scandir(dir_path, suffix=None, recursive=False):
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
if suffix is not None and not case_sensitive:
suffix = suffix.lower() if isinstance(suffix, str) else tuple(
item.lower() for item in suffix)
root = dir_path
def _scandir(dir_path, suffix, recursive):
def _scandir(dir_path, suffix, recursive, case_sensitive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root)
if suffix is None or rel_path.endswith(suffix):
_rel_path = rel_path if case_sensitive else rel_path.lower()
if suffix is None or _rel_path.endswith(suffix):
yield rel_path
elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory
yield from _scandir(
entry.path, suffix=suffix, recursive=recursive)
yield from _scandir(entry.path, suffix, recursive,
case_sensitive)
return _scandir(dir_path, suffix=suffix, recursive=recursive)
return _scandir(dir_path, suffix, recursive, case_sensitive)
def find_vcs_root(path, markers=('.git', )):
......
......@@ -27,7 +27,7 @@ def test_check_file_exist():
def test_scandir():
folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan')
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json']
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT']
assert set(mmcv.scandir(folder)) == set(filenames)
assert set(mmcv.scandir(Path(folder))) == set(filenames)
assert set(mmcv.scandir(folder, '.txt')) == set(
......@@ -41,7 +41,7 @@ def test_scandir():
# path of sep is `\\` in windows but `/` in linux, so osp.join should be
# used to join string for compatibility
filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json',
'a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT',
osp.join('sub', '1.json'),
osp.join('sub', '1.txt'), '.file'
]
......@@ -54,6 +54,19 @@ def test_scandir():
filename for filename in filenames_recursive
if filename.endswith('.txt')
])
assert set(
mmcv.scandir(folder, '.TXT', recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.TXT'))
])
assert set(
mmcv.scandir(
folder, ('.TXT', '.JSON'), recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.json', '.TXT'))
])
with pytest.raises(TypeError):
list(mmcv.scandir(123))
with pytest.raises(TypeError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment