Unverified Commit 47eb1e6a authored by Emmanouil Theofanis Chourdakis's avatar Emmanouil Theofanis Chourdakis Committed by GitHub
Browse files

Changed GTZAN so that it only traverses filenames belonging to the dataset (#791)

* Addressed review issues in PR #668

* Changed GTZAN so that it only traverses filenames belonging to the dataset

Now, instead of walking the whole directory and subdirectories of the dataset
GTZAN only looks for files under a `genre`/`genre`.`5 digit number`.wav format, where `genre` is an allowed GTZAN genre label.
This allows moving or removing files from the dataset (e.g. for fixing duplication or mislabeling issues).
parent 102174e9
...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset ...@@ -8,7 +8,6 @@ from torch.utils.data import Dataset
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
walk_files,
) )
# The following lists prefixed with `filtered_` provide a filtered split # The following lists prefixed with `filtered_` provide a filtered split
...@@ -22,6 +21,19 @@ from torchaudio.datasets.utils import ( ...@@ -22,6 +21,19 @@ from torchaudio.datasets.utils import (
# Those are used when GTZAN is initialised with the `filtered` keyword. # Those are used when GTZAN is initialised with the `filtered` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning. # The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
gtzan_genres = [
"blues",
"classical",
"country",
"disco",
"hiphop",
"jazz",
"metal",
"pop",
"reggae",
"rock",
]
filtered_test = [ filtered_test = [
"blues.00012", "blues.00012",
"blues.00013", "blues.00013",
...@@ -964,7 +976,9 @@ filtered_valid = [ ...@@ -964,7 +976,9 @@ filtered_valid = [
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres" FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"} _CHECKSUMS = {
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"
}
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
...@@ -1032,10 +1046,32 @@ class GTZAN(Dataset): ...@@ -1032,10 +1046,32 @@ class GTZAN(Dataset):
) )
if self.subset is None: if self.subset is None:
walker = walk_files( # Check every subdirectory under dataset root
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True # which has the same name as the genres in
) # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
self._walker = list(walker) # This lets users remove or move around song files,
# useful when e.g. they want to use only some of the files
# in a genre or want to label other files with a different
# genre.
self._walker = []
root = os.path.expanduser(self._path)
for directory in gtzan_genres:
fulldir = os.path.join(root, directory)
if not os.path.exists(fulldir):
continue
songs_in_genre = os.listdir(fulldir)
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
# Check whether the file is of the form
# `gtzan_genre`.`5 digit number`.wav
genre, num = name.split(".")
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
self._walker.append(name)
else: else:
if self.subset == "training": if self.subset == "training":
self._walker = filtered_train self._walker = filtered_train
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment