Commit b92a8a09 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Create musdb handler and tests (#2484)

Summary:
Create dataset handler and tests for new dataset. Manually tested and unit tested to verify validity. Pre-commit run for style checks.

Pull Request resolved: https://github.com/pytorch/audio/pull/2484

Reviewed By: carolineechen, nateanl

Differential Revision: D37250556

Pulled By: skim0514

fbshipit-source-id: d2c8d73d22fd9d7282026265676f3eab1e178d51
parent 66a67d2e
......@@ -144,6 +144,14 @@ FluentSpeechCommands
:special-members: __getitem__
MUSDB_HQ
~~~~~~~~
.. autoclass:: MUSDB_HQ
:members:
:special-members: __getitem__
References
~~~~~~~~~~
......
......@@ -359,6 +359,14 @@
title="YesNo",
url="http://www.openslr.org/1/"
}
@misc{MUSDB18HQ,
author = {Rafii, Zafar and Liutkus, Antoine and St{\"o}ter, Fabian-Robert and Mimilakis, Stylianos Ioannis and
Bittner, Rachel},
title = {{MUSDB18-HQ} - an uncompressed version of MUSDB18},
month = dec,
year = 2019,
doi = {10.5281/zenodo.3338373},
url = {https://doi.org/10.5281/zenodo.3338373}
}
@inproceedings{fluent,
author = {Loren Lugosch and Mirco Ravanelli and Patrick Ignoto and Vikrant Singh Tomar and Yoshua Bengio},
editor = {Gernot Kubin and Zdravko Kacic},
......
import os
import torch
from parameterized import parameterized
from torchaudio.datasets import musdb_hq
from torchaudio.datasets.musdb_hq import _VALIDATION_SET
from torchaudio_unittest.common_utils import (
get_whitenoise,
save_wav,
TempDirMixin,
TorchaudioTestCase,
)
# Source selections exercised by the parameterized tests. ``None`` requests the
# dataset default (all stems except "mixture").
_SOURCE_SETS = [
    (None,),
    (["bass", "drums", "other", "vocals"],),
    (["bass", "drums", "other"],),
    (["bass", "drums", "vocals"],),
    (["bass", "vocals", "other"],),
    (["vocals", "drums", "other"],),
    (["mixture"],),
]
# Fixed RNG seed per stem so each mocked source file carries distinct,
# reproducible audio content.
seed_dict = {
    "bass": 0,
    "drums": 1,
    "other": 2,
    "mixture": 3,
    "vocals": 4,
}
# Audio file extension used by the MUSDB18-HQ layout.
EXT = ".wav"
def _save_sample(dataset_dir, folder, song, source, sample_rate, seed):
    """Write one white-noise stem for a mocked song and return its expected sample.

    Returns a ``(waveform, sample_rate, num_frames, song)`` tuple mirroring what
    the dataset is expected to yield for this track.
    """
    song_dir = os.path.join(dataset_dir, folder, str(song))
    os.makedirs(song_dir, exist_ok=True)
    waveform = get_whitenoise(
        sample_rate=sample_rate,
        duration=5,
        n_channels=2,
        seed=seed,
    )
    save_wav(os.path.join(song_dir, f"{source}{EXT}"), waveform, sample_rate)
    # 5 seconds of audio -> 5 * sample_rate frames
    return waveform, sample_rate, 5 * sample_rate, song
def _get_mocked_samples(dataset_dir, sample_rate):
    """Populate the mocked dataset tree and return expected samples per subset.

    Returns a dict ``{"train": {song: [samples]}, "test": {song: [samples]}}``
    where each song maps to one sample tuple per source stem.
    """
    num_songs = 5
    sources = ["bass", "drums", "other", "mixture", "vocals"]
    expected = {"train": {}, "test": {}}
    song_id = 0
    for subset in ["train", "test"]:
        for _ in range(num_songs):
            name = str(song_id)
            expected[subset][name] = [
                _save_sample(dataset_dir, subset, name, src, sample_rate, seed_dict.get(src)) for src in sources
            ]
            song_id += 1
        # Validation tracks live inside the "train" subset under fixed names.
        if subset == "train":
            for name in _VALIDATION_SET:
                expected[subset][name] = [
                    _save_sample(dataset_dir, subset, name, src, sample_rate, seed_dict.get(src)) for src in sources
                ]
    return expected
def get_mock_dataset(dataset_dir):
    """Create the mocked MUSDB_HQ tree at ``dataset_dir`` and return expected samples.

    dataset_dir: directory to the mocked dataset
    """
    os.makedirs(dataset_dir, exist_ok=True)
    # MUSDB18-HQ audio is 44.1 kHz
    return _get_mocked_samples(dataset_dir, 44100)
class TestMusDB_HQ(TempDirMixin, TorchaudioTestCase):
    """Tests MUSDB_HQ dataset loading against a mocked on-disk dataset."""

    root_dir = None
    backend = "default"
    # Expected samples keyed by song name; populated once in ``setUpClass``.
    train_all_samples = {}
    train_only_samples = {}
    validation_samples = {}
    test_samples = {}

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        dataset_dir = os.path.join(cls.root_dir, "musdb18hq")
        full_dataset = get_mock_dataset(dataset_dir)
        cls.train_all_samples = full_dataset["train"]
        cls.test_samples = full_dataset["test"]
        # Partition the train subset into validation / train-only by
        # membership in the fixed validation track list.
        for key in cls.train_all_samples:
            if key in _VALIDATION_SET:
                cls.validation_samples[key] = cls.train_all_samples[key]
            else:
                cls.train_only_samples[key] = cls.train_all_samples[key]

    def _test_musdb_hq(self, dataset, data_samples, sources):
        """Assert that every item yielded by ``dataset`` matches ``data_samples``."""
        num_samples = 0
        for _, (data, sample_rate, num_frames, name) in enumerate(dataset):
            # Waveforms must be the selected sources stacked in order.
            self.assertEqual(data, self.extractSources(data_samples[name], sources))
            # Rate/frames/name are identical across a song's stems; check the first.
            assert sample_rate == data_samples[name][0][1]
            assert num_frames == data_samples[name][0][2]
            assert name == data_samples[name][0][3]
            num_samples += 1
        # The dataset must yield exactly one item per expected song.
        assert num_samples == len(data_samples)

    @parameterized.expand(_SOURCE_SETS)
    def testMusDBSources_train_all(self, sources):
        # subset="train" without split returns validation + train-only tracks.
        dataset = musdb_hq.MUSDB_HQ(self.root_dir, sources=sources, subset="train")
        self._test_musdb_hq(dataset, self.train_all_samples, sources)

    @parameterized.expand(_SOURCE_SETS)
    def testMusDBSources_train_with_validation(self, sources):
        # split="train" excludes the validation tracks.
        dataset = musdb_hq.MUSDB_HQ(
            self.root_dir,
            sources=sources,
            subset="train",
            split="train",
        )
        self._test_musdb_hq(dataset, self.train_only_samples, sources)

    @parameterized.expand(_SOURCE_SETS)
    def testMusDBSources_validation(self, sources):
        # split="validation" returns only the fixed validation tracks.
        dataset = musdb_hq.MUSDB_HQ(
            self.root_dir,
            sources=sources,
            subset="train",
            split="validation",
        )
        self._test_musdb_hq(dataset, self.validation_samples, sources)

    @parameterized.expand(_SOURCE_SETS)
    def testMusDBSources_test(self, sources):
        dataset = musdb_hq.MUSDB_HQ(
            self.root_dir,
            sources=sources,
            subset="test",
        )
        self._test_musdb_hq(dataset, self.test_samples, sources)

    def extractSources(self, samples, sources):
        """Stack the expected waveforms for ``sources`` (default: all stems but mixture)."""
        sources = ["bass", "drums", "other", "vocals"] if not sources else sources
        # ``seed_dict[source]`` doubles as the stem's index in the sample list.
        return torch.stack([samples[seed_dict[source]][0] for source in sources])
......@@ -9,6 +9,7 @@ from .librimix import LibriMix
from .librispeech import LIBRISPEECH
from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .musdb_hq import MUSDB_HQ
from .quesst14 import QUESST14
from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM
......@@ -32,5 +33,6 @@ __all__ = [
"LIBRITTS",
"TEDLIUM",
"QUESST14",
"MUSDB_HQ",
"FluentSpeechCommands",
]
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
import torchaudio
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
# Download location and integrity checksum of the MUSDB18-HQ archive (Zenodo).
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
# All tracks are distributed as 44.1 kHz WAV files.
_EXT = ".wav"
_SAMPLE_RATE = 44100
# Track names held out of the "train" subset when split="validation"/"train".
_VALIDATION_SET = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]
class MUSDB_HQ(Dataset):
    """Create *MUSDB_HQ* [:footcite:`MUSDB18HQ`] Dataset

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
        sources (List[str] or None, optional): Sources extract data from.
            List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
            If ``None``, dataset consists of tracks except mixture.
            (default: ``None``)
        split (str or None, optional): Whether to split training set into train and validation set.
            If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
            (default: ``None``)
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        sources: Optional[List[str]] = None,
        split: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # Default source list omits "mixture"; note an empty list also falls
        # back to the default because of the falsy check.
        self.sources = ["bass", "drums", "other", "vocals"] if not sources else sources
        self.split = split
        basename = os.path.basename(_URL)
        archive = os.path.join(root, basename)
        # Drop the ".zip" suffix to obtain the extraction directory name.
        basename = basename.rsplit(".", 2)[0]
        # NOTE(review): ``assert`` validation is stripped under ``python -O``;
        # callers relying on these checks should not run optimized.
        assert subset in ["test", "train"], "`subset` must be one of ['test', 'train']"
        assert self.split is None or self.split in [
            "train",
            "validation",
        ], "`split` must be one of ['train', 'validation']"
        base_path = os.path.join(root, basename)
        self._path = os.path.join(base_path, subset)
        # Download and/or extract only when the subset directory is absent.
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            os.makedirs(base_path, exist_ok=True)
            extract_archive(archive, base_path)
        self.names = self._collect_songs()

    def _get_track(self, name, source):
        # Path of one stem file: <subset dir>/<song name>/<source>.wav
        return Path(self._path) / name / f"{source}{_EXT}"

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th song: stack its source waveforms along a new dim 0."""
        name = self.names[n]
        wavs = []
        num_frames = None
        for source in self.sources:
            track = self._get_track(name, source)
            wav, sr = torchaudio.load(str(track))
            assert sr == _SAMPLE_RATE, f"expected sample rate {_SAMPLE_RATE}, but got {sr}"
            # All stems of a song must have the same length to be stackable.
            if num_frames is None:
                num_frames = wav.shape[-1]
            else:
                assert wav.shape[-1] == num_frames, "num_frames do not match across sources"
            wavs.append(wav)
        stacked = torch.stack(wavs)
        return stacked, _SAMPLE_RATE, num_frames, name

    def _collect_songs(self):
        """Return the song names belonging to the configured subset/split."""
        if self.split == "validation":
            return _VALIDATION_SET
        path = Path(self._path)
        names = []
        for root, folders, _ in os.walk(path, followlinks=True):
            root = Path(root)
            # Keep only leaf directories (song folders): skip hidden dirs,
            # any directory that still has subfolders, and the subset root.
            if root.name.startswith(".") or folders or root == path:
                continue
            name = str(root.relative_to(path))
            # With split="train", exclude the held-out validation tracks.
            if self.split and name in _VALIDATION_SET:
                continue
            names.append(name)
        return sorted(names)

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, int, str): ``(waveforms, sample_rate, num_frames, track_name)``
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        # One item per collected song name.
        return len(self.names)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment