Unverified Commit f5aced81 authored by Krishna Kalyan's avatar Krishna Kalyan Committed by GitHub
Browse files

Refactor YesNo dataset (#1127)


Co-authored-by: default avatarkrishnakalyan3 <skalyan@cloudera.com>
parent e43a8e76
...@@ -11,23 +11,14 @@ from torchaudio.datasets.utils import ( ...@@ -11,23 +11,14 @@ from torchaudio.datasets.utils import (
extract_archive, extract_archive,
) )
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
FOLDER_IN_ARCHIVE = "waves_yesno"
_CHECKSUMS = {
"http://www.openslr.org/resources/1/waves_yesno.tar.gz":
"962ff6e904d2df1126132ecec6978786"
}
def load_yesno_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, List[int]]:
# Read label
labels = [int(c) for c in fileid.split("_")]
# Read wav _RELEASE_CONFIGS = {
file_audio = os.path.join(path, fileid + ext_audio) "release1": {
waveform, sample_rate = torchaudio.load(file_audio) "folder_in_archive": "waves_yesno",
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
return waveform, sample_rate, labels "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
}
}
class YESNO(Dataset): class YESNO(Dataset):
...@@ -43,25 +34,26 @@ class YESNO(Dataset): ...@@ -43,25 +34,26 @@ class YESNO(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``). Whether to download the dataset if it is not found at root path. (default: ``False``).
""" """
_ext_audio = ".wav" def __init__(
self,
def __init__(self, root: Union[str, Path],
root: Union[str, Path], url: str = _RELEASE_CONFIGS["release1"]["url"],
url: str = URL, folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
download: bool = False) -> None: ) -> None:
# Get string representation of 'root' in case Path object is passed self._parse_filesystem(root, url, folder_in_archive, download)
root = os.fspath(root)
def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
root = Path(root)
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = root / archive
self._path = os.path.join(root, folder_in_archive)
self._path = root / folder_in_archive
if download: if download:
if not os.path.isdir(self._path): if not os.path.isdir(self._path):
if not os.path.isfile(archive): if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None) checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url(url, root, hash_value=checksum, hash_type="md5") download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive) extract_archive(archive)
...@@ -70,7 +62,13 @@ class YESNO(Dataset): ...@@ -70,7 +62,13 @@ class YESNO(Dataset):
"Dataset not found. Please use `download=True` to download it." "Dataset not found. Please use `download=True` to download it."
) )
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*' + self._ext_audio)) self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))
def _load_item(self, fileid: str, path: str):
labels = [int(c) for c in fileid.split("_")]
file_audio = os.path.join(path, fileid + ".wav")
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, labels
def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
...@@ -82,13 +80,8 @@ class YESNO(Dataset): ...@@ -82,13 +80,8 @@ class YESNO(Dataset):
tuple: ``(waveform, sample_rate, labels)`` tuple: ``(waveform, sample_rate, labels)``
""" """
fileid = self._walker[n] fileid = self._walker[n]
item = load_yesno_item(fileid, self._path, self._ext_audio) item = self._load_item(fileid, self._path)
return item
# TODO Upon deprecation, uncomment line below and remove following code
# return item
waveform, sample_rate, labels = item
return waveform, sample_rate, labels
def __len__(self) -> int: def __len__(self) -> int:
return len(self._walker) return len(self._walker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment