Unverified Commit e8489a7b authored by Ezra-Yu's avatar Ezra-Yu Committed by GitHub
Browse files

Add case case_sensitive in scandir (#1389)

* add case_insensitive

* rename v

* case_insensitive to case_sensitive

* Update docstring
parent c85c240f
...@@ -36,7 +36,7 @@ def symlink(src, dst, overwrite=True, **kwargs): ...@@ -36,7 +36,7 @@ def symlink(src, dst, overwrite=True, **kwargs):
os.symlink(src, dst, **kwargs) os.symlink(src, dst, **kwargs)
def scandir(dir_path, suffix=None, recursive=False): def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files. """Scan a directory to find the interested files.
Args: Args:
...@@ -45,6 +45,8 @@ def scandir(dir_path, suffix=None, recursive=False): ...@@ -45,6 +45,8 @@ def scandir(dir_path, suffix=None, recursive=False):
interested in. Default: None. interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the recursive (bool, optional): If set to True, recursively scan the
directory. Default: False. directory. Default: False.
case_sensitive (bool, optional) : If set to False, ignore the case of
suffix. Default: True.
Returns: Returns:
A generator for all the interested files with relative paths. A generator for all the interested files with relative paths.
...@@ -57,20 +59,25 @@ def scandir(dir_path, suffix=None, recursive=False): ...@@ -57,20 +59,25 @@ def scandir(dir_path, suffix=None, recursive=False):
if (suffix is not None) and not isinstance(suffix, (str, tuple)): if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings') raise TypeError('"suffix" must be a string or tuple of strings')
if suffix is not None and not case_sensitive:
suffix = suffix.lower() if isinstance(suffix, str) else tuple(
item.lower() for item in suffix)
root = dir_path root = dir_path
def _scandir(dir_path, suffix, recursive): def _scandir(dir_path, suffix, recursive, case_sensitive):
for entry in os.scandir(dir_path): for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file(): if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root) rel_path = osp.relpath(entry.path, root)
if suffix is None or rel_path.endswith(suffix): _rel_path = rel_path if case_sensitive else rel_path.lower()
if suffix is None or _rel_path.endswith(suffix):
yield rel_path yield rel_path
elif recursive and os.path.isdir(entry.path): elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory # scan recursively if entry.path is a directory
yield from _scandir( yield from _scandir(entry.path, suffix, recursive,
entry.path, suffix=suffix, recursive=recursive) case_sensitive)
return _scandir(dir_path, suffix=suffix, recursive=recursive) return _scandir(dir_path, suffix, recursive, case_sensitive)
def find_vcs_root(path, markers=('.git', )): def find_vcs_root(path, markers=('.git', )):
......
...@@ -27,7 +27,7 @@ def test_check_file_exist(): ...@@ -27,7 +27,7 @@ def test_check_file_exist():
def test_scandir(): def test_scandir():
folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan') folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan')
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json'] filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT']
assert set(mmcv.scandir(folder)) == set(filenames) assert set(mmcv.scandir(folder)) == set(filenames)
assert set(mmcv.scandir(Path(folder))) == set(filenames) assert set(mmcv.scandir(Path(folder))) == set(filenames)
assert set(mmcv.scandir(folder, '.txt')) == set( assert set(mmcv.scandir(folder, '.txt')) == set(
...@@ -41,7 +41,7 @@ def test_scandir(): ...@@ -41,7 +41,7 @@ def test_scandir():
# path of sep is `\\` in windows but `/` in linux, so osp.join should be # path of sep is `\\` in windows but `/` in linux, so osp.join should be
# used to join string for compatibility # used to join string for compatibility
filenames_recursive = [ filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json', 'a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT',
osp.join('sub', '1.json'), osp.join('sub', '1.json'),
osp.join('sub', '1.txt'), '.file' osp.join('sub', '1.txt'), '.file'
] ]
...@@ -54,6 +54,19 @@ def test_scandir(): ...@@ -54,6 +54,19 @@ def test_scandir():
filename for filename in filenames_recursive filename for filename in filenames_recursive
if filename.endswith('.txt') if filename.endswith('.txt')
]) ])
assert set(
mmcv.scandir(folder, '.TXT', recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.TXT'))
])
assert set(
mmcv.scandir(
folder, ('.TXT', '.JSON'), recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.json', '.TXT'))
])
with pytest.raises(TypeError): with pytest.raises(TypeError):
list(mmcv.scandir(123)) list(mmcv.scandir(123))
with pytest.raises(TypeError): with pytest.raises(TypeError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment