Unverified commit d25a4ddf, authored by Krishna Kalyan, committed by GitHub
Browse files

Using Path and glob instead of walk_files (#1069)



Datasets updated:
- yesno
- librispeech
- libritts
- speechcommands

Co-authored-by: krishnakalyan3 <skalyan@cloudera.com>
Co-authored-by: Vincent Quenneville-Belair <vincentqb@gmail.com>
parent 79c97fb0
...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset ...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
walk_files,
) )
URL = "train-clean-100" URL = "train-clean-100"
...@@ -125,10 +124,7 @@ class LIBRISPEECH(Dataset): ...@@ -125,10 +124,7 @@ class LIBRISPEECH(Dataset):
download_url(url, root, hash_value=checksum) download_url(url, root, hash_value=checksum)
extract_archive(archive) extract_archive(archive)
walker = walk_files( self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
)
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
......
...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset ...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
walk_files,
) )
URL = "train-clean-100" URL = "train-clean-100"
...@@ -126,10 +125,7 @@ class LIBRITTS(Dataset): ...@@ -126,10 +125,7 @@ class LIBRITTS(Dataset):
download_url(url, root, hash_value=checksum) download_url(url, root, hash_value=checksum)
extract_archive(archive) extract_archive(archive)
walker = walk_files( self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
)
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
......
...@@ -8,7 +8,6 @@ from torch import Tensor ...@@ -8,7 +8,6 @@ from torch import Tensor
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
walk_files
) )
FOLDER_IN_ARCHIVE = "SpeechCommands" FOLDER_IN_ARCHIVE = "SpeechCommands"
...@@ -110,7 +109,7 @@ class SPEECHCOMMANDS(Dataset): ...@@ -110,7 +109,7 @@ class SPEECHCOMMANDS(Dataset):
self._walker = _load_list(self._path, "testing_list.txt") self._walker = _load_list(self._path, "testing_list.txt")
elif subset == "training": elif subset == "training":
excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt")) excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
walker = walk_files(self._path, suffix=".wav", prefix=True) walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
self._walker = [ self._walker = [
w for w in walker w for w in walker
if HASH_DIVIDER in w if HASH_DIVIDER in w
...@@ -118,7 +117,7 @@ class SPEECHCOMMANDS(Dataset): ...@@ -118,7 +117,7 @@ class SPEECHCOMMANDS(Dataset):
and os.path.normpath(w) not in excludes and os.path.normpath(w) not in excludes
] ]
else: else:
walker = walk_files(self._path, suffix=".wav", prefix=True) walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]: def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
......
...@@ -9,7 +9,6 @@ from torch.utils.data import Dataset ...@@ -9,7 +9,6 @@ from torch.utils.data import Dataset
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
walk_files
) )
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz" URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
...@@ -85,10 +84,7 @@ class YESNO(Dataset): ...@@ -85,10 +84,7 @@ class YESNO(Dataset):
"Dataset not found. Please use `download=True` to download it." "Dataset not found. Please use `download=True` to download it."
) )
walker = walk_files( self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*' + self._ext_audio))
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
)
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment